cache-dit 0.3.2__py3-none-any.whl → 1.0.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cache_dit/__init__.py +37 -19
- cache_dit/_version.py +2 -2
- cache_dit/caching/__init__.py +36 -0
- cache_dit/{cache_factory → caching}/block_adapters/__init__.py +149 -18
- cache_dit/{cache_factory → caching}/block_adapters/block_adapters.py +91 -7
- cache_dit/caching/block_adapters/block_registers.py +118 -0
- cache_dit/caching/cache_adapters/__init__.py +1 -0
- cache_dit/{cache_factory → caching}/cache_adapters/cache_adapter.py +262 -123
- cache_dit/caching/cache_blocks/__init__.py +226 -0
- cache_dit/caching/cache_blocks/offload_utils.py +115 -0
- cache_dit/caching/cache_blocks/pattern_0_1_2.py +26 -0
- cache_dit/caching/cache_blocks/pattern_3_4_5.py +543 -0
- cache_dit/caching/cache_blocks/pattern_base.py +748 -0
- cache_dit/caching/cache_blocks/pattern_utils.py +86 -0
- cache_dit/caching/cache_contexts/__init__.py +28 -0
- cache_dit/caching/cache_contexts/cache_config.py +120 -0
- cache_dit/{cache_factory → caching}/cache_contexts/cache_context.py +29 -90
- cache_dit/{cache_factory → caching}/cache_contexts/cache_manager.py +138 -10
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/__init__.py +25 -3
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/foca.py +1 -1
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/taylorseer.py +81 -9
- cache_dit/caching/cache_contexts/context_manager.py +36 -0
- cache_dit/caching/cache_contexts/prune_config.py +63 -0
- cache_dit/caching/cache_contexts/prune_context.py +155 -0
- cache_dit/caching/cache_contexts/prune_manager.py +167 -0
- cache_dit/caching/cache_interface.py +358 -0
- cache_dit/{cache_factory → caching}/cache_types.py +19 -2
- cache_dit/{cache_factory → caching}/forward_pattern.py +14 -14
- cache_dit/{cache_factory → caching}/params_modifier.py +10 -10
- cache_dit/caching/patch_functors/__init__.py +15 -0
- cache_dit/{cache_factory → caching}/patch_functors/functor_chroma.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_dit.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_flux.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_hidream.py +2 -4
- cache_dit/{cache_factory → caching}/patch_functors/functor_hunyuan_dit.py +1 -1
- cache_dit/caching/patch_functors/functor_qwen_image_controlnet.py +263 -0
- cache_dit/caching/utils.py +68 -0
- cache_dit/metrics/__init__.py +11 -0
- cache_dit/metrics/metrics.py +3 -0
- cache_dit/parallelism/__init__.py +3 -0
- cache_dit/parallelism/backends/native_diffusers/__init__.py +6 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +164 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/__init__.py +4 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_attention_dispatch.py +304 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_chroma.py +95 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogvideox.py +202 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogview.py +299 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cosisid.py +123 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_dit.py +94 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_flux.py +88 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_hunyuan.py +729 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_ltxvideo.py +264 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_nunchaku.py +407 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_pixart.py +285 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_qwen_image.py +104 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +84 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_wan.py +101 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +117 -0
- cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +49 -0
- cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
- cache_dit/parallelism/backends/native_pytorch/__init__.py +6 -0
- cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +171 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_kandinsky5.py +79 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +65 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +14 -0
- cache_dit/parallelism/parallel_backend.py +26 -0
- cache_dit/parallelism/parallel_config.py +88 -0
- cache_dit/parallelism/parallel_interface.py +77 -0
- cache_dit/quantize/__init__.py +7 -0
- cache_dit/quantize/backends/__init__.py +1 -0
- cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
- cache_dit/quantize/backends/torchao/__init__.py +1 -0
- cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +44 -30
- cache_dit/quantize/quantize_backend.py +0 -0
- cache_dit/quantize/quantize_config.py +0 -0
- cache_dit/quantize/quantize_interface.py +3 -16
- cache_dit/summary.py +593 -0
- cache_dit/utils.py +46 -290
- cache_dit-1.0.14.dist-info/METADATA +301 -0
- cache_dit-1.0.14.dist-info/RECORD +102 -0
- cache_dit-1.0.14.dist-info/licenses/LICENSE +203 -0
- cache_dit/cache_factory/__init__.py +0 -28
- cache_dit/cache_factory/block_adapters/block_registers.py +0 -90
- cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
- cache_dit/cache_factory/cache_blocks/__init__.py +0 -72
- cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +0 -16
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +0 -238
- cache_dit/cache_factory/cache_blocks/pattern_base.py +0 -404
- cache_dit/cache_factory/cache_blocks/utils.py +0 -41
- cache_dit/cache_factory/cache_contexts/__init__.py +0 -14
- cache_dit/cache_factory/cache_interface.py +0 -217
- cache_dit/cache_factory/patch_functors/__init__.py +0 -12
- cache_dit/cache_factory/utils.py +0 -57
- cache_dit-0.3.2.dist-info/METADATA +0 -753
- cache_dit-0.3.2.dist-info/RECORD +0 -56
- cache_dit-0.3.2.dist-info/licenses/LICENSE +0 -53
- /cache_dit/{cache_factory → caching}/.gitignore +0 -0
- /cache_dit/{cache_factory → caching}/cache_contexts/calibrators/base.py +0 -0
- /cache_dit/{cache_factory → caching}/patch_functors/functor_base.py +0 -0
- /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
- /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
- {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/WHEEL +0 -0
- {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/top_level.txt +0 -0
cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_ltxvideo.py
@@ -0,0 +1,264 @@
+import torch
+import functools
+from typing import Optional
+import torch.nn.functional as F
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.transformers.transformer_ltx import (
+    LTXVideoTransformer3DModel,
+    LTXAttention,
+    AttentionModuleMixin,
+    LTXVideoAttnProcessor,
+    apply_rotary_emb,
+)
+from diffusers.models.attention_dispatch import dispatch_attention_fn
+
+try:
+    from diffusers.models._modeling_parallel import (
+        ContextParallelInput,
+        ContextParallelOutput,
+        ContextParallelModelPlan,
+    )
+except ImportError:
+    raise ImportError(
+        "Context parallelism requires the 'diffusers>=0.36.dev0'."
+        "Please install latest version of diffusers from source: \n"
+        "pip3 install git+https://github.com/huggingface/diffusers.git"
+    )
+from .cp_plan_registers import (
+    ContextParallelismPlanner,
+    ContextParallelismPlannerRegister,
+)
+
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+@ContextParallelismPlannerRegister.register("LTXVideo")
+class LTXVideoContextParallelismPlanner(ContextParallelismPlanner):
+    def apply(
+        self,
+        transformer: Optional[torch.nn.Module | ModelMixin] = None,
+        **kwargs,
+    ) -> ContextParallelModelPlan:
+        assert transformer is not None, "Transformer must be provided."
+        assert isinstance(
+            transformer, LTXVideoTransformer3DModel
+        ), "Transformer must be an instance of LTXVideoTransformer3DModel"
+
+        # NOTE: The attention_mask preparation in LTXAttention while using
+        # context parallelism is buggy in diffusers v0.36.0.dev0, so we
+        # disable the preference for the native diffusers implementation here.
+        self._cp_planner_preferred_native_diffusers = False
+
+        if (
+            transformer is not None
+            and self._cp_planner_preferred_native_diffusers
+        ):
+            if hasattr(transformer, "_cp_plan"):
+                if transformer._cp_plan is not None:
+                    return transformer._cp_plan
+
+        # Apply monkey patches to fix attention mask preparation at class level
+        assert issubclass(LTXAttention, AttentionModuleMixin)
+        LTXAttention.prepare_attention_mask = (
+            __patch__LTXAttention_prepare_attention_mask__
+        )
+        LTXVideoAttnProcessor.__call__ = __patch__LTXVideoAttnProcessor__call__
+
+        # Otherwise, use the custom CP plan defined here, which may be
+        # a little different from the native diffusers implementation
+        # for some models.
+
+        _cp_plan = {
+            # Here is a Transformer-level CP plan for LTXVideo, which will
+            # apply only one split hook (pre_forward) on the forward
+            # of the Transformer, and gather the output after the Transformer forward.
+            # Pattern of the transformer forward, split_output=False:
+            #   un-split input -> split input (inside transformer)
+            # Pattern of the transformer_blocks:
+            #   split input (previous split output) -> to_qkv/...
+            #   -> all2all
+            #   -> attn (local heads, full seqlen)
+            #   -> all2all
+            #   -> split output
+            # The `hidden_states` and `encoder_hidden_states` will still remain
+            # split after each block forward, namely, hidden_states will be
+            # re-split automatically by the all2all comm op after attn, and the
+            # encoder_hidden_states will stay split after the entry point
+            # of the transformer forward, for all blocks.
+            "": {
+                "hidden_states": ContextParallelInput(
+                    split_dim=1, expected_dims=3, split_output=False
+                ),
+                "encoder_hidden_states": ContextParallelInput(
+                    split_dim=1, expected_dims=3, split_output=False
+                ),
+                # NOTE: encoder_attention_mask (namely, attention_mask in cross-attn)
+                # should never be split across seqlen while using context parallelism
+                # for LTXVideoTransformer3DModel. It doesn't contribute to any computation,
+                # parallel or not. So we comment it out here and handle the head split
+                # correctly in the patched attention processor when using context parallelism.
+                # "encoder_attention_mask": ContextParallelInput(
+                #     split_dim=1, expected_dims=2, split_output=False
+                # ),
+            },
+            # Pattern of rope, split_output=True (split output rather than input):
+            #   un-split input
+            #   -> keep input un-split
+            #   -> rope
+            #   -> split output
+            "rope": {
+                0: ContextParallelInput(
+                    split_dim=1, expected_dims=3, split_output=True
+                ),
+                1: ContextParallelInput(
+                    split_dim=1, expected_dims=3, split_output=True
+                ),
+            },
+            # Then, the final proj_out will gather the split output.
+            #   split input (previous split output)
+            #   -> all gather
+            #   -> un-split output
+            "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+        }
+        return _cp_plan
+
+
+@functools.wraps(LTXAttention.prepare_attention_mask)
+def __patch__LTXAttention_prepare_attention_mask__(
+    self: LTXAttention,
+    attention_mask: torch.Tensor,
+    target_length: int,
+    batch_size: int,
+    out_dim: int = 3,
+    # NOTE(DefTruth): Allow specifying head_size for CP
+    head_size: Optional[int] = None,
+) -> torch.Tensor:
+    """
+    Prepare the attention mask for the attention computation.
+
+    Args:
+        attention_mask (`torch.Tensor`): The attention mask to prepare.
+        target_length (`int`): The target length of the attention mask.
+        batch_size (`int`): The batch size for repeating the attention mask.
+        out_dim (`int`, *optional*, defaults to `3`): Output dimension.
+
+    Returns:
+        `torch.Tensor`: The prepared attention mask.
+    """
+    if head_size is None:
+        head_size = self.heads
+    if attention_mask is None:
+        return attention_mask
+
+    current_length: int = attention_mask.shape[-1]
+    if current_length != target_length:
+        if attention_mask.device.type == "mps":
+            # HACK: MPS: Does not support padding by greater than dimension of input tensor.
+            # Instead, we can manually construct the padding tensor.
+            padding_shape = (
+                attention_mask.shape[0],
+                attention_mask.shape[1],
+                target_length,
+            )
+            padding = torch.zeros(
+                padding_shape,
+                dtype=attention_mask.dtype,
+                device=attention_mask.device,
+            )
+            attention_mask = torch.cat([attention_mask, padding], dim=2)
+        else:
+            # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
+            #   we want to instead pad by (0, remaining_length), where remaining_length is:
+            #   remaining_length: int = target_length - current_length
+            # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
+            attention_mask = F.pad(
+                attention_mask, (0, target_length), value=0.0
+            )
+
+    if out_dim == 3:
+        if attention_mask.shape[0] < batch_size * head_size:
+            attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
+    elif out_dim == 4:
+        attention_mask = attention_mask.unsqueeze(1)
+        attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
+
+    return attention_mask
+
+
+@functools.wraps(LTXVideoAttnProcessor.__call__)
+def __patch__LTXVideoAttnProcessor__call__(
+    self: LTXVideoAttnProcessor,
+    attn: "LTXAttention",
+    hidden_states: torch.Tensor,
+    encoder_hidden_states: Optional[torch.Tensor] = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    image_rotary_emb: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    batch_size, sequence_length, _ = (
+        hidden_states.shape
+        if encoder_hidden_states is None
+        else encoder_hidden_states.shape
+    )
+
+    if attention_mask is not None:
+        if self._parallel_config is None:
+            attention_mask = attn.prepare_attention_mask(
+                attention_mask, sequence_length, batch_size
+            )
+            attention_mask = attention_mask.view(
+                batch_size, attn.heads, -1, attention_mask.shape[-1]
+            )
+        else:
+            # NOTE(DefTruth): Fix attention mask preparation for context parallelism
+            cp_config = getattr(
+                self._parallel_config, "context_parallel_config", None
+            )
+            if cp_config is not None and cp_config._world_size > 1:
+                head_size = attn.heads // cp_config._world_size
+                attention_mask = attn.prepare_attention_mask(
+                    attention_mask,
+                    sequence_length * cp_config._world_size,
+                    batch_size,
+                    3,
+                    head_size,
+                )
+                attention_mask = attention_mask.view(
+                    batch_size, head_size, -1, attention_mask.shape[-1]
+                )
+
+    if encoder_hidden_states is None:
+        encoder_hidden_states = hidden_states
+
+    query = attn.to_q(hidden_states)
+    key = attn.to_k(encoder_hidden_states)
+    value = attn.to_v(encoder_hidden_states)
+
+    query = attn.norm_q(query)
+    key = attn.norm_k(key)
+
+    if image_rotary_emb is not None:
+        query = apply_rotary_emb(query, image_rotary_emb)
+        key = apply_rotary_emb(key, image_rotary_emb)
+
+    query = query.unflatten(2, (attn.heads, -1))
+    key = key.unflatten(2, (attn.heads, -1))
+    value = value.unflatten(2, (attn.heads, -1))
+
+    hidden_states = dispatch_attention_fn(
+        query,
+        key,
+        value,
+        attn_mask=attention_mask,
+        dropout_p=0.0,
+        is_causal=False,
+        backend=self._attention_backend,
+        parallel_config=self._parallel_config,
+    )
+    hidden_states = hidden_states.flatten(2, 3)
+    hidden_states = hidden_states.to(query.dtype)
+
+    hidden_states = attn.to_out[0](hidden_states)
+    hidden_states = attn.to_out[1](hidden_states)
+    return hidden_states
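
To make the head-split bookkeeping in the patched prepare_attention_mask concrete, here is a minimal standalone sketch (illustrative numbers only, not code from the wheel): under context parallelism each rank keeps heads // world_size local heads after the all2all but attends over the full key/value sequence, so the cross-attention mask is repeated per local head rather than per global head.

# Minimal sketch of the CP mask bookkeeping (illustrative, not wheel code).
import torch

batch_size, heads, world_size = 2, 32, 4
target_length = 512                       # full key/value sequence length

local_heads = heads // world_size         # 8 local heads per rank after all2all
mask = torch.ones(batch_size, target_length)           # [B, S_kv] cross-attn mask
mask = mask.repeat_interleave(local_heads, dim=0)       # [B * local_heads, S_kv]
mask = mask.view(batch_size, local_heads, -1, mask.shape[-1])
print(mask.shape)                         # torch.Size([2, 8, 1, 512])
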
cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_nunchaku.py
@@ -0,0 +1,407 @@
+import torch
+import functools
+from typing import Optional, Tuple
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.attention_dispatch import dispatch_attention_fn
+from diffusers.models.transformers.transformer_qwenimage import (
+    apply_rotary_emb_qwen,
+)
+
+try:
+    from nunchaku.models.transformers.transformer_flux_v2 import (
+        NunchakuFluxAttention,
+        NunchakuFluxFA2Processor,
+        NunchakuFluxTransformer2DModelV2,
+    )
+    from nunchaku.ops.fused import fused_qkv_norm_rottary
+    from nunchaku.models.transformers.transformer_qwenimage import (
+        NunchakuQwenAttention,
+        NunchakuQwenImageNaiveFA2Processor,
+        NunchakuQwenImageTransformer2DModel,
+    )
+except ImportError:
+    raise ImportError(
+        "NunchakuFluxTransformer2DModelV2 or NunchakuQwenImageTransformer2DModel "
+        "requires the 'nunchaku' package. Please install nunchaku before using "
+        "the context parallelism for nunchaku 4-bits models."
+    )
+
+try:
+    from diffusers.models._modeling_parallel import (
+        ContextParallelInput,
+        ContextParallelOutput,
+        ContextParallelModelPlan,
+    )
+except ImportError:
+    raise ImportError(
+        "Context parallelism requires the 'diffusers>=0.36.dev0'."
+        "Please install latest version of diffusers from source: \n"
+        "pip3 install git+https://github.com/huggingface/diffusers.git"
+    )
+from .cp_plan_registers import (
+    ContextParallelismPlanner,
+    ContextParallelismPlannerRegister,
+)
+
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+@ContextParallelismPlannerRegister.register("NunchakuFlux")
+class NunchakuFluxContextParallelismPlanner(ContextParallelismPlanner):
+    def apply(
+        self,
+        transformer: Optional[torch.nn.Module | ModelMixin] = None,
+        **kwargs,
+    ) -> ContextParallelModelPlan:
+
+        self._cp_planner_preferred_native_diffusers = False
+
+        if (
+            transformer is not None
+            and self._cp_planner_preferred_native_diffusers
+        ):
+
+            assert isinstance(
+                transformer, NunchakuFluxTransformer2DModelV2
+            ), "Transformer must be an instance of NunchakuFluxTransformer2DModelV2"
+            if hasattr(transformer, "_cp_plan"):
+                if transformer._cp_plan is not None:
+                    return transformer._cp_plan
+
+        NunchakuFluxFA2Processor.__call__ = (
+            __patch_NunchakuFluxFA2Processor__call__
+        )
+        # Also need to patch the parallel config and attention backend
+        if not hasattr(NunchakuFluxFA2Processor, "_parallel_config"):
+            NunchakuFluxFA2Processor._parallel_config = None
+        if not hasattr(NunchakuFluxFA2Processor, "_attention_backend"):
+            NunchakuFluxFA2Processor._attention_backend = None
+        if not hasattr(NunchakuFluxAttention, "_parallel_config"):
+            NunchakuFluxAttention._parallel_config = None
+        if not hasattr(NunchakuFluxAttention, "_attention_backend"):
+            NunchakuFluxAttention._attention_backend = None
+
+        # Otherwise, use the custom CP plan defined here, which may be
+        # a little different from the native diffusers implementation
+        # for some models.
+        _cp_plan = {
+            # Here is a Transformer-level CP plan for Flux, which will
+            # apply only one split hook (pre_forward) on the forward
+            # of the Transformer, and gather the output after the Transformer forward.
+            # Pattern of the transformer forward, split_output=False:
+            #   un-split input -> split input (inside transformer)
+            # Pattern of the transformer_blocks, single_transformer_blocks:
+            #   split input (previous split output) -> to_qkv/...
+            #   -> all2all
+            #   -> attn (local heads, full seqlen)
+            #   -> all2all
+            #   -> split output
+            # The `hidden_states` and `encoder_hidden_states` will still remain
+            # split after each block forward (namely, re-split automatically by
+            # the all2all comm op after attn) for all blocks.
+            # img_ids and txt_ids will only be split once at the very beginning,
+            # and stay split through the whole transformer forward. The all2all
+            # comm op only happens on the `out` tensor after the local attn, not on
+            # img_ids and txt_ids.
+            "": {
+                "hidden_states": ContextParallelInput(
+                    split_dim=1, expected_dims=3, split_output=False
+                ),
+                "encoder_hidden_states": ContextParallelInput(
+                    split_dim=1, expected_dims=3, split_output=False
+                ),
+                "img_ids": ContextParallelInput(
+                    split_dim=0, expected_dims=2, split_output=False
+                ),
+                "txt_ids": ContextParallelInput(
+                    split_dim=0, expected_dims=2, split_output=False
+                ),
+            },
+            # Then, the final proj_out will gather the split output.
+            #   split input (previous split output)
+            #   -> all gather
+            #   -> un-split output
+            "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+        }
+        return _cp_plan
+
+
+@functools.wraps(NunchakuFluxFA2Processor.__call__)
+def __patch_NunchakuFluxFA2Processor__call__(
+    self: NunchakuFluxFA2Processor,
+    attn: NunchakuFluxAttention,
+    hidden_states: torch.Tensor,
+    encoder_hidden_states: Optional[torch.Tensor] = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    image_rotary_emb: Tuple[torch.Tensor, torch.Tensor] | torch.Tensor = None,
+    **kwargs,
+) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor]:
+    # Most of the original NunchakuFluxFA2Processor.__call__ is kept
+    # unchanged here; only the attention computation is modified so
+    # that it supports context parallelism.
+    if attention_mask is not None:
+        raise NotImplementedError("attention_mask is not supported")
+
+    batch_size, _, channels = hidden_states.shape
+    assert channels == attn.heads * attn.head_dim
+    qkv = fused_qkv_norm_rottary(
+        hidden_states,
+        attn.to_qkv,
+        attn.norm_q,
+        attn.norm_k,
+        (
+            image_rotary_emb[0]
+            if isinstance(image_rotary_emb, tuple)
+            else image_rotary_emb
+        ),
+    )
+
+    if attn.added_kv_proj_dim is not None:
+        assert encoder_hidden_states is not None
+        assert isinstance(image_rotary_emb, tuple)
+        qkv_context = fused_qkv_norm_rottary(
+            encoder_hidden_states,
+            attn.add_qkv_proj,
+            attn.norm_added_q,
+            attn.norm_added_k,
+            image_rotary_emb[1],
+        )
+        qkv = torch.cat([qkv_context, qkv], dim=1)
+
+    query, key, value = qkv.chunk(3, dim=-1)
+    # Original implementation:
+    # query = query.view(batch_size, -1, attn.heads, attn.head_dim).transpose(
+    #     1, 2
+    # )
+    # key = key.view(batch_size, -1, attn.heads, attn.head_dim).transpose(1, 2)
+    # value = value.view(batch_size, -1, attn.heads, attn.head_dim).transpose(
+    #     1, 2
+    # )
+    # hidden_states = F.scaled_dot_product_attention(
+    #     query,
+    #     key,
+    #     value,
+    #     attn_mask=attention_mask,
+    #     dropout_p=0.0,
+    #     is_causal=False,
+    # )
+    # hidden_states = hidden_states.transpose(1, 2).reshape(
+    #     batch_size, -1, attn.heads * attn.head_dim
+    # )
+    # hidden_states = hidden_states.to(query.dtype)
+
+    # NOTE(DefTruth): Monkey patch to support context parallelism
+    query = query.view(batch_size, -1, attn.heads, attn.head_dim)
+    key = key.view(batch_size, -1, attn.heads, attn.head_dim)
+    value = value.view(batch_size, -1, attn.heads, attn.head_dim)
+
+    hidden_states = dispatch_attention_fn(
+        query,
+        key,
+        value,
+        attn_mask=attention_mask,
+        backend=getattr(self, "_attention_backend", None),
+        parallel_config=getattr(self, "_parallel_config", None),
+    )
+    hidden_states = hidden_states.flatten(2, 3)
+    hidden_states = hidden_states.to(query.dtype)
+
+    if encoder_hidden_states is not None:
+        encoder_hidden_states, hidden_states = (
+            hidden_states[:, : encoder_hidden_states.shape[1]],
+            hidden_states[:, encoder_hidden_states.shape[1] :],
+        )
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+        encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+        return hidden_states, encoder_hidden_states
+    else:
+        # for single transformer blocks, the proj_out is split into two linear layers
+        hidden_states = attn.to_out(hidden_states)
+        return hidden_states
+
+
+@ContextParallelismPlannerRegister.register("NunchakuQwenImage")
+class NunchakuQwenImageContextParallelismPlanner(ContextParallelismPlanner):
+    def apply(
+        self,
+        transformer: Optional[torch.nn.Module | ModelMixin] = None,
+        **kwargs,
+    ) -> ContextParallelModelPlan:
+
+        self._cp_planner_preferred_native_diffusers = False
+
+        if (
+            transformer is not None
+            and self._cp_planner_preferred_native_diffusers
+        ):
+
+            assert isinstance(
+                transformer, NunchakuQwenImageTransformer2DModel
+            ), "Transformer must be an instance of NunchakuQwenImageTransformer2DModel"
+            if hasattr(transformer, "_cp_plan"):
+                if transformer._cp_plan is not None:
+                    return transformer._cp_plan
+
+        NunchakuQwenImageNaiveFA2Processor.__call__ = (
+            __patch_NunchakuQwenImageNaiveFA2Processor__call__
+        )
+        # Also need to patch the parallel config and attention backend
+        if not hasattr(NunchakuQwenImageNaiveFA2Processor, "_parallel_config"):
+            NunchakuQwenImageNaiveFA2Processor._parallel_config = None
+        if not hasattr(
+            NunchakuQwenImageNaiveFA2Processor, "_attention_backend"
+        ):
+            NunchakuQwenImageNaiveFA2Processor._attention_backend = None
+        if not hasattr(NunchakuQwenAttention, "_parallel_config"):
+            NunchakuQwenAttention._parallel_config = None
+        if not hasattr(NunchakuQwenAttention, "_attention_backend"):
+            NunchakuQwenAttention._attention_backend = None
+
+        # Otherwise, use the custom CP plan defined here, which may be
+        # a little different from the native diffusers implementation
+        # for some models.
+        _cp_plan = {
+            # Here is a Transformer-level CP plan for QwenImage, which will
+            # apply only one split hook (pre_forward) on the forward
+            # of the Transformer, and gather the output after the Transformer forward.
+            # Pattern of the transformer forward, split_output=False:
+            #   un-split input -> split input (inside transformer)
+            # Pattern of the transformer_blocks:
+            #   split input (previous split output) -> to_qkv/...
+            #   -> all2all
+            #   -> attn (local heads, full seqlen)
+            #   -> all2all
+            #   -> split output
+            # The `hidden_states` and `encoder_hidden_states` will still remain
+            # split after each block forward (namely, re-split automatically by
+            # the all2all comm op after attn) for all blocks.
+            "": {
+                "hidden_states": ContextParallelInput(
+                    split_dim=1, expected_dims=3, split_output=False
+                ),
+                # NOTE: Due to the joint attention implementation of
+                # QwenImageTransformerBlock, we must split the
+                # encoder_hidden_states as well.
+                "encoder_hidden_states": ContextParallelInput(
+                    split_dim=1, expected_dims=3, split_output=False
+                ),
+                # NOTE: But encoder_hidden_states_mask seems never to be used in
+                # QwenImageTransformerBlock, so we do not split it here.
+                # "encoder_hidden_states_mask": ContextParallelInput(
+                #     split_dim=1, expected_dims=2, split_output=False
+                # ),
+            },
+            # Pattern of pos_embed, split_output=True (split output rather than input):
+            #   un-split input
+            #   -> keep input un-split
+            #   -> rope
+            #   -> split output
+            "pos_embed": {
+                0: ContextParallelInput(
+                    split_dim=0, expected_dims=2, split_output=True
+                ),
+                1: ContextParallelInput(
+                    split_dim=0, expected_dims=2, split_output=True
+                ),
+            },
+            # Then, the final proj_out will gather the split output.
+            #   split input (previous split output)
+            #   -> all gather
+            #   -> un-split output
+            "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+        }
+        return _cp_plan
+
+
+@functools.wraps(NunchakuQwenImageNaiveFA2Processor.__call__)
+def __patch_NunchakuQwenImageNaiveFA2Processor__call__(
+    self,
+    attn,
+    hidden_states: torch.FloatTensor,
+    encoder_hidden_states: torch.FloatTensor = None,
+    encoder_hidden_states_mask: torch.FloatTensor = None,
+    attention_mask: Optional[torch.FloatTensor] = None,
+    image_rotary_emb: Optional[torch.Tensor] = None,
+) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor]:
+    if encoder_hidden_states is None:
+        raise ValueError(
+            "NunchakuQwenImageFA2Processor requires encoder_hidden_states (text stream)"
+        )
+
+    seq_txt = encoder_hidden_states.shape[1]
+
+    # Compute QKV for image stream (sample projections)
+    img_qkv = attn.to_qkv(hidden_states)
+    img_query, img_key, img_value = img_qkv.chunk(3, dim=-1)
+
+    # Compute QKV for text stream (context projections)
+    txt_qkv = attn.add_qkv_proj(encoder_hidden_states)
+    txt_query, txt_key, txt_value = txt_qkv.chunk(3, dim=-1)
+
+    # Reshape for multi-head attention
+    img_query = img_query.unflatten(-1, (attn.heads, -1))  # [B, L, H, D]
+    img_key = img_key.unflatten(-1, (attn.heads, -1))
+    img_value = img_value.unflatten(-1, (attn.heads, -1))
+
+    txt_query = txt_query.unflatten(-1, (attn.heads, -1))
+    txt_key = txt_key.unflatten(-1, (attn.heads, -1))
+    txt_value = txt_value.unflatten(-1, (attn.heads, -1))
+
+    # Apply QK normalization
+    assert attn.norm_q is not None
+    img_query = attn.norm_q(img_query)
+    assert attn.norm_k is not None
+    img_key = attn.norm_k(img_key)
+    assert attn.norm_added_q is not None
+    txt_query = attn.norm_added_q(txt_query)
+    assert attn.norm_added_k is not None
+    txt_key = attn.norm_added_k(txt_key)
+
+    # Apply rotary embeddings
+    if image_rotary_emb is not None:
+        img_freqs, txt_freqs = image_rotary_emb
+        img_query = apply_rotary_emb_qwen(img_query, img_freqs, use_real=False)
+        img_key = apply_rotary_emb_qwen(img_key, img_freqs, use_real=False)
+        txt_query = apply_rotary_emb_qwen(txt_query, txt_freqs, use_real=False)
+        txt_key = apply_rotary_emb_qwen(txt_key, txt_freqs, use_real=False)
+
+    # Concatenate for joint attention: [text, image]
+    joint_query = torch.cat([txt_query, img_query], dim=1)
+    joint_key = torch.cat([txt_key, img_key], dim=1)
+    joint_value = torch.cat([txt_value, img_value], dim=1)
+
+    # Compute joint attention
+    joint_hidden_states = dispatch_attention_fn(
+        joint_query,
+        joint_key,
+        joint_value,
+        attn_mask=attention_mask,
+        dropout_p=0.0,
+        is_causal=False,
+        # NOTE(DefTruth): Use the patched attention backend and
+        # parallel config to make context parallelism work here.
+        backend=getattr(self, "_attention_backend", None),
+        parallel_config=getattr(self, "_parallel_config", None),
+    )
+
+    # Reshape back
+    joint_hidden_states = joint_hidden_states.flatten(2, 3)
+    joint_hidden_states = joint_hidden_states.to(joint_query.dtype)
+
+    # Split attention outputs back
+    txt_attn_output = joint_hidden_states[:, :seq_txt, :]  # Text part
+    img_attn_output = joint_hidden_states[:, seq_txt:, :]  # Image part
+
+    # Apply output projections
+    img_attn_output = attn.to_out[0](img_attn_output)
+    if len(attn.to_out) > 1:
+        img_attn_output = attn.to_out[1](img_attn_output)  # dropout
+
+    txt_attn_output = attn.to_add_out(txt_attn_output)
+
+    return img_attn_output, txt_attn_output
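
Both new modules follow the same registration pattern: a planner class registers itself under a model-family key and returns a transformer-level plan built from ContextParallelInput / ContextParallelOutput. As a hedged sketch of how a third-party planner could plug into the same registry (the "MyDiT" key and the MyDiTContextParallelismPlanner class below are hypothetical, not part of cache-dit), it would look roughly like this:

# Hypothetical planner illustrating the registry pattern used by the two
# modules above; "MyDiT" and MyDiTContextParallelismPlanner are not real
# cache-dit symbols.
import torch
from typing import Optional

from diffusers.models.modeling_utils import ModelMixin
from diffusers.models._modeling_parallel import (
    ContextParallelInput,
    ContextParallelOutput,
    ContextParallelModelPlan,
)
from cache_dit.parallelism.backends.native_diffusers.context_parallelism.cp_plan_registers import (
    ContextParallelismPlanner,
    ContextParallelismPlannerRegister,
)


@ContextParallelismPlannerRegister.register("MyDiT")
class MyDiTContextParallelismPlanner(ContextParallelismPlanner):
    def apply(
        self,
        transformer: Optional[torch.nn.Module | ModelMixin] = None,
        **kwargs,
    ) -> ContextParallelModelPlan:
        # Split the sequence dim once at the transformer entry point and
        # gather the final projection output, as in the planners above.
        return {
            "": {
                "hidden_states": ContextParallelInput(
                    split_dim=1, expected_dims=3, split_output=False
                ),
            },
            "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
        }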