cache-dit 0.3.2__py3-none-any.whl → 1.0.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cache_dit/__init__.py +37 -19
- cache_dit/_version.py +2 -2
- cache_dit/caching/__init__.py +36 -0
- cache_dit/{cache_factory → caching}/block_adapters/__init__.py +149 -18
- cache_dit/{cache_factory → caching}/block_adapters/block_adapters.py +91 -7
- cache_dit/caching/block_adapters/block_registers.py +118 -0
- cache_dit/caching/cache_adapters/__init__.py +1 -0
- cache_dit/{cache_factory → caching}/cache_adapters/cache_adapter.py +262 -123
- cache_dit/caching/cache_blocks/__init__.py +226 -0
- cache_dit/caching/cache_blocks/offload_utils.py +115 -0
- cache_dit/caching/cache_blocks/pattern_0_1_2.py +26 -0
- cache_dit/caching/cache_blocks/pattern_3_4_5.py +543 -0
- cache_dit/caching/cache_blocks/pattern_base.py +748 -0
- cache_dit/caching/cache_blocks/pattern_utils.py +86 -0
- cache_dit/caching/cache_contexts/__init__.py +28 -0
- cache_dit/caching/cache_contexts/cache_config.py +120 -0
- cache_dit/{cache_factory → caching}/cache_contexts/cache_context.py +29 -90
- cache_dit/{cache_factory → caching}/cache_contexts/cache_manager.py +138 -10
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/__init__.py +25 -3
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/foca.py +1 -1
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/taylorseer.py +81 -9
- cache_dit/caching/cache_contexts/context_manager.py +36 -0
- cache_dit/caching/cache_contexts/prune_config.py +63 -0
- cache_dit/caching/cache_contexts/prune_context.py +155 -0
- cache_dit/caching/cache_contexts/prune_manager.py +167 -0
- cache_dit/caching/cache_interface.py +358 -0
- cache_dit/{cache_factory → caching}/cache_types.py +19 -2
- cache_dit/{cache_factory → caching}/forward_pattern.py +14 -14
- cache_dit/{cache_factory → caching}/params_modifier.py +10 -10
- cache_dit/caching/patch_functors/__init__.py +15 -0
- cache_dit/{cache_factory → caching}/patch_functors/functor_chroma.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_dit.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_flux.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_hidream.py +2 -4
- cache_dit/{cache_factory → caching}/patch_functors/functor_hunyuan_dit.py +1 -1
- cache_dit/caching/patch_functors/functor_qwen_image_controlnet.py +263 -0
- cache_dit/caching/utils.py +68 -0
- cache_dit/metrics/__init__.py +11 -0
- cache_dit/metrics/metrics.py +3 -0
- cache_dit/parallelism/__init__.py +3 -0
- cache_dit/parallelism/backends/native_diffusers/__init__.py +6 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +164 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/__init__.py +4 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_attention_dispatch.py +304 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_chroma.py +95 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogvideox.py +202 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogview.py +299 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cosisid.py +123 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_dit.py +94 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_flux.py +88 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_hunyuan.py +729 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_ltxvideo.py +264 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_nunchaku.py +407 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_pixart.py +285 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_qwen_image.py +104 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +84 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_wan.py +101 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +117 -0
- cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +49 -0
- cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
- cache_dit/parallelism/backends/native_pytorch/__init__.py +6 -0
- cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +171 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_kandinsky5.py +79 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +65 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +14 -0
- cache_dit/parallelism/parallel_backend.py +26 -0
- cache_dit/parallelism/parallel_config.py +88 -0
- cache_dit/parallelism/parallel_interface.py +77 -0
- cache_dit/quantize/__init__.py +7 -0
- cache_dit/quantize/backends/__init__.py +1 -0
- cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
- cache_dit/quantize/backends/torchao/__init__.py +1 -0
- cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +44 -30
- cache_dit/quantize/quantize_backend.py +0 -0
- cache_dit/quantize/quantize_config.py +0 -0
- cache_dit/quantize/quantize_interface.py +3 -16
- cache_dit/summary.py +593 -0
- cache_dit/utils.py +46 -290
- cache_dit-1.0.14.dist-info/METADATA +301 -0
- cache_dit-1.0.14.dist-info/RECORD +102 -0
- cache_dit-1.0.14.dist-info/licenses/LICENSE +203 -0
- cache_dit/cache_factory/__init__.py +0 -28
- cache_dit/cache_factory/block_adapters/block_registers.py +0 -90
- cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
- cache_dit/cache_factory/cache_blocks/__init__.py +0 -72
- cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +0 -16
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +0 -238
- cache_dit/cache_factory/cache_blocks/pattern_base.py +0 -404
- cache_dit/cache_factory/cache_blocks/utils.py +0 -41
- cache_dit/cache_factory/cache_contexts/__init__.py +0 -14
- cache_dit/cache_factory/cache_interface.py +0 -217
- cache_dit/cache_factory/patch_functors/__init__.py +0 -12
- cache_dit/cache_factory/utils.py +0 -57
- cache_dit-0.3.2.dist-info/METADATA +0 -753
- cache_dit-0.3.2.dist-info/RECORD +0 -56
- cache_dit-0.3.2.dist-info/licenses/LICENSE +0 -53
- /cache_dit/{cache_factory → caching}/.gitignore +0 -0
- /cache_dit/{cache_factory → caching}/cache_contexts/calibrators/base.py +0 -0
- /cache_dit/{cache_factory → caching}/patch_functors/functor_base.py +0 -0
- /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
- /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
- {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/WHEEL +0 -0
- {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/top_level.txt +0 -0
cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_attention_dispatch.py
ADDED
|
@@ -0,0 +1,304 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import torch
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
try:
    from diffusers.models.attention_dispatch import (
        _AttentionBackendRegistry,
        AttentionBackendName,
        _check_device,
        _check_shape,
        TemplatedRingAttention,
        TemplatedUlyssesAttention,
    )
    from diffusers.models._modeling_parallel import ParallelConfig
except ImportError as err:
    # Chain the original error so the actually-missing module/symbol stays
    # visible in the traceback next to the installation hint.
    raise ImportError(
        "Context parallelism requires the 'diffusers>=0.36.dev0'. "
        "Please install latest version of diffusers from source: \n"
        "pip3 install git+https://github.com/huggingface/diffusers.git"
    ) from err
|
|
21
|
+
from cache_dit.logger import init_logger
|
|
22
|
+
|
|
23
|
+
logger = init_logger(__name__)


# Public export of this module. `_native_attention` is either the
# context-parallel-aware re-implementation registered in this module (when
# CACHE_DIT_ENABLE_CUSTOM_CP_NATIVE_ATTN_DISPATCH is enabled) or diffusers'
# own implementation, imported in the fallback branch.
__all__ = [
    "_native_attention",
]
|
|
29
|
+
|
|
30
|
+
def _parse_env_flag(name: str, default: str = "1") -> bool:
    """Read a boolean flag from the environment variable ``name``.

    Integer values keep the historical ``bool(int(value))`` semantics
    (``0`` -> False, any other integer -> True). The words
    ``true/false``, ``yes/no`` and ``on/off`` (case-insensitive) are also
    accepted, instead of crashing the import with an opaque ``int()`` error.

    Args:
        name: Environment variable to read.
        default: Raw value used when the variable is unset.

    Raises:
        ValueError: If the value is neither an integer nor a recognised word.
    """
    raw = os.getenv(name, default)
    lowered = raw.strip().lower()
    if lowered in ("true", "yes", "on"):
        return True
    if lowered in ("false", "no", "off"):
        return False
    try:
        return bool(int(raw))
    except ValueError:
        raise ValueError(
            f"Invalid value {raw!r} for environment variable `{name}`: "
            "expected an integer or one of true/false, yes/no, on/off."
        ) from None


# Enable the custom native attention backend with context parallelism by
# default, for better compatibility. Set the environment variable to 0 (or
# "false") to disable it and keep diffusers' own native attention backend.
_CACHE_DIT_ENABLE_CUSTOM_CP_NATIVE_ATTN_DISPATCH = _parse_env_flag(
    "CACHE_DIT_ENABLE_CUSTOM_CP_NATIVE_ATTN_DISPATCH", "1"
)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def _is_native_attn_supported_context_parallel() -> bool:
    """Return True if diffusers' registry marks NATIVE as context-parallel capable.

    The registry's ``_supports_context_parallel`` may be either a dict
    (backend -> bool) or a set of backend *values*. Both layouts are
    handled explicitly rather than by catching the TypeError raised when
    indexing a set, and without relying on ``assert`` for type checks.
    """
    registry = _AttentionBackendRegistry._supports_context_parallel
    native = AttentionBackendName.NATIVE
    if isinstance(registry, dict):
        # Missing key means "not supported"; normalise truthy values to bool.
        return bool(registry.get(native, False))
    # Set-like registry: membership of the backend's string value.
    return native.value in registry
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
if _CACHE_DIT_ENABLE_CUSTOM_CP_NATIVE_ATTN_DISPATCH:
    logger.warning(
        "Re-registering NATIVE attention backend to enable context parallelism. "
        "This is a temporary workaround and should be removed after the native "
        "attention backend supports context parallelism natively. Please check: "
        "https://github.com/huggingface/diffusers/pull/12563 for more details. "
        "Or, you can disable this behavior by setting the environment variable "
        "`CACHE_DIT_ENABLE_CUSTOM_CP_NATIVE_ATTN_DISPATCH=0`."
    )
    # Drop diffusers' records for NATIVE (function, constraints, accepted
    # argument names) so the register decorator applied to the replacement
    # `_native_attention` in this module re-creates them from scratch.
    for _native_registry_table in (
        _AttentionBackendRegistry._backends,
        _AttentionBackendRegistry._constraints,
        _AttentionBackendRegistry._supported_arg_names,
    ):
        _native_registry_table.pop(AttentionBackendName.NATIVE)
    # Keep the loop variable from leaking into the module namespace.
    del _native_registry_table

    # Also clear an existing "supports CP" marker. Depending on its layout,
    # the registry stores it as a dict keyed by backend or as a set of values.
    if _is_native_attn_supported_context_parallel():
        if not isinstance(
            _AttentionBackendRegistry._supports_context_parallel, dict
        ):
            _AttentionBackendRegistry._supports_context_parallel.remove(
                AttentionBackendName.NATIVE.value
            )
        else:
            _AttentionBackendRegistry._supports_context_parallel.pop(
                AttentionBackendName.NATIVE
            )
|
|
82
|
+
|
|
83
|
+
# Re-define templated context parallel attention to support attn mask.
def _templated_context_parallel_attention_v2(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    enable_gqa: bool = False,
    return_lse: bool = False,
    *,
    forward_op,
    backward_op,
    _parallel_config: Optional["ParallelConfig"] = None,
):
    """Run attention through diffusers' templated ring / Ulysses CP autograd ops.

    Re-defined here so that an ``attn_mask`` can be passed through. This is
    only allowed together with the native attention forward op.

    Args:
        query, key, value: Attention inputs, forwarded unchanged to the
            templated op.
        attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse:
            Attention options forwarded to the templated op.
        forward_op, backward_op: Per-shard attention kernels used by the
            templated op.
        _parallel_config: Parallel config. Its ``context_parallel_config``
            ring/Ulysses degrees select the implementation. Required.

    Raises:
        ValueError: ``attn_mask`` combined with a non-native ``forward_op``;
            ``is_causal`` or ``enable_gqa`` requested (not supported yet);
            ``_parallel_config`` missing; or neither degree greater than 1.
    """
    if attn_mask is not None:
        # NOTE(DefTruth): Check if forward_op is native attention forward op.
        # getattr keeps callables without __name__ (e.g. functools.partial)
        # on the intended ValueError path instead of raising AttributeError.
        forward_op_name = getattr(forward_op, "__name__", repr(forward_op))
        if forward_op_name != "_native_attention_forward_op":
            raise ValueError(
                "Templated context parallel attention with attn_mask "
                "is only supported for native attention backend, "
                f"but got forward_op: {forward_op_name}."
            )
    if is_causal:
        raise ValueError(
            "Causal attention is not yet supported for templated attention."
        )
    if enable_gqa:
        raise ValueError("GQA is not yet supported for templated attention.")
    if _parallel_config is None:
        # Previously this dereferenced None and raised an AttributeError.
        raise ValueError(
            "Templated context parallel attention requires `_parallel_config`, "
            "but got None."
        )

    cp_config = _parallel_config.context_parallel_config
    # TODO: add support for unified attention with ring/ulysses degree both being > 1
    if cp_config.ring_degree > 1:
        attention_cls = TemplatedRingAttention
    elif cp_config.ulysses_degree > 1:
        attention_cls = TemplatedUlyssesAttention
    else:
        raise ValueError(
            "Reaching this branch of code is unexpected. Please report a bug."
        )
    return attention_cls.apply(
        query,
        key,
        value,
        attn_mask,
        dropout_p,
        is_causal,
        scale,
        enable_gqa,
        return_lse,
        forward_op,
        backward_op,
        _parallel_config,
    )
|
|
152
|
+
|
|
153
|
+
# NOTE:Remove NATIVE attention backend constraints and re-register it.
|
|
154
|
+
# Here is a temporary workaround to enable context parallelism with
|
|
155
|
+
# native attention backend. We should remove this workaround after
|
|
156
|
+
# the native attention backend supports context parallelism natively.
|
|
157
|
+
# Adapted from: https://github.com/huggingface/diffusers/pull/12563
|
|
158
|
+
|
|
159
|
+
def _native_attention_forward_op(
|
|
160
|
+
ctx: torch.autograd.function.FunctionCtx,
|
|
161
|
+
query: torch.Tensor,
|
|
162
|
+
key: torch.Tensor,
|
|
163
|
+
value: torch.Tensor,
|
|
164
|
+
attn_mask: Optional[torch.Tensor] = None,
|
|
165
|
+
dropout_p: float = 0.0,
|
|
166
|
+
is_causal: bool = False,
|
|
167
|
+
scale: Optional[float] = None,
|
|
168
|
+
enable_gqa: bool = False,
|
|
169
|
+
return_lse: bool = False,
|
|
170
|
+
_save_ctx: bool = True,
|
|
171
|
+
_parallel_config: Optional["ParallelConfig"] = None,
|
|
172
|
+
):
|
|
173
|
+
# Native attention does not return_lse
|
|
174
|
+
if return_lse:
|
|
175
|
+
raise ValueError(
|
|
176
|
+
"Native attention does not support return_lse=True"
|
|
177
|
+
)
|
|
178
|
+
|
|
179
|
+
# used for backward pass
|
|
180
|
+
if _save_ctx:
|
|
181
|
+
ctx.save_for_backward(query, key, value)
|
|
182
|
+
ctx.attn_mask = attn_mask
|
|
183
|
+
ctx.dropout_p = dropout_p
|
|
184
|
+
ctx.is_causal = is_causal
|
|
185
|
+
ctx.scale = scale
|
|
186
|
+
ctx.enable_gqa = enable_gqa
|
|
187
|
+
|
|
188
|
+
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
|
|
189
|
+
out = torch.nn.functional.scaled_dot_product_attention(
|
|
190
|
+
query=query,
|
|
191
|
+
key=key,
|
|
192
|
+
value=value,
|
|
193
|
+
attn_mask=attn_mask,
|
|
194
|
+
dropout_p=dropout_p,
|
|
195
|
+
is_causal=is_causal,
|
|
196
|
+
scale=scale,
|
|
197
|
+
enable_gqa=enable_gqa,
|
|
198
|
+
)
|
|
199
|
+
out = out.permute(0, 2, 1, 3)
|
|
200
|
+
|
|
201
|
+
return out
|
|
202
|
+
|
|
203
|
+
def _native_attention_backward_op(
|
|
204
|
+
ctx: torch.autograd.function.FunctionCtx,
|
|
205
|
+
grad_out: torch.Tensor,
|
|
206
|
+
*args,
|
|
207
|
+
**kwargs,
|
|
208
|
+
):
|
|
209
|
+
query, key, value = ctx.saved_tensors
|
|
210
|
+
|
|
211
|
+
query.requires_grad_(True)
|
|
212
|
+
key.requires_grad_(True)
|
|
213
|
+
value.requires_grad_(True)
|
|
214
|
+
|
|
215
|
+
query_t, key_t, value_t = (
|
|
216
|
+
x.permute(0, 2, 1, 3) for x in (query, key, value)
|
|
217
|
+
)
|
|
218
|
+
out = torch.nn.functional.scaled_dot_product_attention(
|
|
219
|
+
query=query_t,
|
|
220
|
+
key=key_t,
|
|
221
|
+
value=value_t,
|
|
222
|
+
attn_mask=ctx.attn_mask,
|
|
223
|
+
dropout_p=ctx.dropout_p,
|
|
224
|
+
is_causal=ctx.is_causal,
|
|
225
|
+
scale=ctx.scale,
|
|
226
|
+
enable_gqa=ctx.enable_gqa,
|
|
227
|
+
)
|
|
228
|
+
out = out.permute(0, 2, 1, 3)
|
|
229
|
+
|
|
230
|
+
grad_out_t = grad_out.permute(0, 2, 1, 3)
|
|
231
|
+
grad_query_t, grad_key_t, grad_value_t = torch.autograd.grad(
|
|
232
|
+
outputs=out,
|
|
233
|
+
inputs=[query_t, key_t, value_t],
|
|
234
|
+
grad_outputs=grad_out_t,
|
|
235
|
+
retain_graph=False,
|
|
236
|
+
)
|
|
237
|
+
|
|
238
|
+
grad_query = grad_query_t.permute(0, 2, 1, 3)
|
|
239
|
+
grad_key = grad_key_t.permute(0, 2, 1, 3)
|
|
240
|
+
grad_value = grad_value_t.permute(0, 2, 1, 3)
|
|
241
|
+
|
|
242
|
+
return grad_query, grad_key, grad_value
|
|
243
|
+
|
|
244
|
+
@_AttentionBackendRegistry.register(
    AttentionBackendName.NATIVE,
    constraints=[_check_device, _check_shape],
    supports_context_parallel=True,
)
def _native_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    enable_gqa: bool = False,
    return_lse: bool = False,
    _parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
    """NATIVE attention backend entry point, registered with CP support.

    Without a parallel config, this is plain SDPA. Inputs are permuted with
    ``(0, 2, 1, 3)`` around the call, so they are presumably
    ``(batch, seq_len, heads, head_dim)``. With a parallel config, the call
    is routed through ``_templated_context_parallel_attention_v2`` using the
    native forward/backward ops.

    Raises:
        ValueError: If ``return_lse`` is requested.
    """
    if return_lse:
        raise ValueError(
            "Native attention backend does not support setting `return_lse=True`."
        )

    if _parallel_config is not None:
        # Context-parallel path (ring or Ulysses, chosen by the config).
        return _templated_context_parallel_attention_v2(
            query,
            key,
            value,
            attn_mask,
            dropout_p,
            is_causal,
            scale,
            enable_gqa,
            return_lse,
            forward_op=_native_attention_forward_op,
            backward_op=_native_attention_backward_op,
            _parallel_config=_parallel_config,
        )

    # Single-device path: SDPA wants the heads dimension before the sequence.
    heads_first_out = torch.nn.functional.scaled_dot_product_attention(
        query=query.permute(0, 2, 1, 3),
        key=key.permute(0, 2, 1, 3),
        value=value.permute(0, 2, 1, 3),
        attn_mask=attn_mask,
        dropout_p=dropout_p,
        is_causal=is_causal,
        scale=scale,
        enable_gqa=enable_gqa,
    )
    return heads_first_out.permute(0, 2, 1, 3)
|
|
296
|
+
|
|
297
|
+
else:
|
|
298
|
+
from diffusers.models.attention_dispatch import (
|
|
299
|
+
_native_attention,
|
|
300
|
+
) # noqa: F401
|
|
301
|
+
|
|
302
|
+
logger.info(
|
|
303
|
+
"Native attention backend already supports context parallelism."
|
|
304
|
+
)
|
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from typing import Optional
|
|
3
|
+
from diffusers.models.modeling_utils import ModelMixin
|
|
4
|
+
from diffusers.models.transformers.transformer_chroma import (
|
|
5
|
+
ChromaTransformer2DModel,
|
|
6
|
+
)
|
|
7
|
+
|
|
8
|
+
try:
    from diffusers.models._modeling_parallel import (
        ContextParallelInput,
        ContextParallelOutput,
        ContextParallelModelPlan,
    )
except ImportError as err:
    # Chain the original error so the actually-missing module/symbol stays
    # visible in the traceback next to the installation hint.
    raise ImportError(
        "Context parallelism requires the 'diffusers>=0.36.dev0'. "
        "Please install latest version of diffusers from source: \n"
        "pip3 install git+https://github.com/huggingface/diffusers.git"
    ) from err
|
|
20
|
+
from .cp_plan_registers import (
|
|
21
|
+
ContextParallelismPlanner,
|
|
22
|
+
ContextParallelismPlannerRegister,
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
from cache_dit.logger import init_logger
|
|
26
|
+
|
|
27
|
+
logger = init_logger(__name__)


@ContextParallelismPlannerRegister.register("Chroma")
class ChromaContextParallelismPlanner(ContextParallelismPlanner):
    """Context-parallel (CP) planner for diffusers' ChromaTransformer2DModel."""

    def apply(
        self,
        transformer: Optional[torch.nn.Module | ModelMixin] = None,
        **kwargs,
    ) -> ContextParallelModelPlan:
        """Build the CP plan for a Chroma transformer.

        The diffusers-native plan is not used for Chroma yet: the preference
        flag is forced to False, so the custom plan below is always returned.
        """
        # NOTE: Diffusers native CP plan still not supported for Chroma now.
        self._cp_planner_preferred_native_diffusers = False

        if transformer is not None and self._cp_planner_preferred_native_diffusers:
            assert isinstance(
                transformer, ChromaTransformer2DModel
            ), "Transformer must be an instance of ChromaTransformer2DModel"
            native_plan = getattr(transformer, "_cp_plan", None)
            if native_plan is not None:
                return native_plan

        # Custom plan; it may differ slightly from diffusers' native plan.
        #
        # Transformer-level plan: a single split hook (pre_forward) on the
        # transformer forward, and a gather at `proj_out`.
        #   transformer forward (split_output=False):
        #     un-split input -> split input (inside the transformer)
        #   transformer_blocks / single_transformer_blocks:
        #     split input (previous split output) -> to_qkv/...
        #       -> all2all -> attn (local head, full seqlen) -> all2all
        #       -> split output
        # `hidden_states` and `encoder_hidden_states` stay split across all
        # blocks, because the all2all after attention re-splits them.
        # `img_ids` / `txt_ids` are split once at the start and stay split;
        # all2all only acts on the attention output, never on the ids.
        return {
            "": {
                "hidden_states": ContextParallelInput(
                    split_dim=1, expected_dims=3, split_output=False
                ),
                "encoder_hidden_states": ContextParallelInput(
                    split_dim=1, expected_dims=3, split_output=False
                ),
                "img_ids": ContextParallelInput(
                    split_dim=0, expected_dims=2, split_output=False
                ),
                "txt_ids": ContextParallelInput(
                    split_dim=0, expected_dims=2, split_output=False
                ),
            },
            # proj_out: split input -> all gather -> un-split output.
            "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
        }
|
|
@@ -0,0 +1,202 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import functools
|
|
3
|
+
from typing import Optional, Tuple
|
|
4
|
+
from diffusers.models.modeling_utils import ModelMixin
|
|
5
|
+
from diffusers.models.transformers.cogvideox_transformer_3d import (
|
|
6
|
+
CogVideoXAttnProcessor2_0,
|
|
7
|
+
CogVideoXTransformer3DModel,
|
|
8
|
+
)
|
|
9
|
+
from diffusers.models.attention_processor import Attention
|
|
10
|
+
from diffusers.models.attention_dispatch import dispatch_attention_fn
|
|
11
|
+
|
|
12
|
+
try:
    from diffusers.models._modeling_parallel import (
        ContextParallelInput,
        ContextParallelOutput,
        ContextParallelModelPlan,
    )
except ImportError as err:
    # Chain the original error so the actually-missing module/symbol stays
    # visible in the traceback next to the installation hint.
    raise ImportError(
        "Context parallelism requires the 'diffusers>=0.36.dev0'. "
        "Please install latest version of diffusers from source: \n"
        "pip3 install git+https://github.com/huggingface/diffusers.git"
    ) from err
|
|
24
|
+
from .cp_plan_registers import (
|
|
25
|
+
ContextParallelismPlanner,
|
|
26
|
+
ContextParallelismPlannerRegister,
|
|
27
|
+
)
|
|
28
|
+
|
|
29
|
+
from cache_dit.logger import init_logger
|
|
30
|
+
|
|
31
|
+
logger = init_logger(__name__)


@ContextParallelismPlannerRegister.register("CogVideoX")
class CogVideoXContextParallelismPlanner(ContextParallelismPlanner):
    """Context-parallel (CP) planner for diffusers' CogVideoXTransformer3DModel."""

    def apply(
        self,
        transformer: Optional[torch.nn.Module | ModelMixin] = None,
        **kwargs,
    ) -> ContextParallelModelPlan:
        """Build the CP plan for a CogVideoX transformer.

        The diffusers-native plan is not used for CogVideoX yet: the preference
        flag is forced to False. Side effect: replaces
        ``CogVideoXAttnProcessor2_0.__call__`` class-wide with a version that
        goes through ``dispatch_attention_fn``, and gives the class ``None``
        defaults for ``_parallel_config`` / ``_attention_backend`` if missing.
        """
        # NOTE: Diffusers native CP plan still not supported for CogVideoX now.
        self._cp_planner_preferred_native_diffusers = False

        if transformer is not None and self._cp_planner_preferred_native_diffusers:
            assert isinstance(
                transformer, CogVideoXTransformer3DModel
            ), "Transformer must be an instance of CogVideoXTransformer3DModel"
            native_plan = getattr(transformer, "_cp_plan", None)
            if native_plan is not None:
                return native_plan

        # Route the attention processor through dispatch_attention_fn.
        CogVideoXAttnProcessor2_0.__call__ = (
            __patch_CogVideoXAttnProcessor2_0__call__
        )
        # The patched __call__ reads these attributes; provide class defaults.
        for attr_name in ("_parallel_config", "_attention_backend"):
            if not hasattr(CogVideoXAttnProcessor2_0, attr_name):
                setattr(CogVideoXAttnProcessor2_0, attr_name, None)

        # Custom plan; it may differ slightly from diffusers' native plan.
        return {
            # transformer_blocks.0 (split_output=False):
            #   un-split input -> split -> to_qkv/... -> all2all
            #   -> attn (local head, full seqlen) -> all2all -> split output
            # Remaining blocks receive the previous (already split) output.
            # The attention output is split back into `hidden_states` and
            # `encoder_hidden_states` after every block, so both stay split
            # through the all2all after local attention; they only need
            # splitting at the first block.
            "transformer_blocks.0": {
                "hidden_states": ContextParallelInput(
                    split_dim=1, expected_dims=3, split_output=False
                ),
                "encoder_hidden_states": ContextParallelInput(
                    split_dim=1, expected_dims=3, split_output=False
                ),
            },
            # `image_rotary_emb` is never re-split by all2all and is still
            # un-split when each block finishes, so it is split again at every
            # block: un-split input -> split -> block forward -> un-split ...
            "transformer_blocks.*": {
                "image_rotary_emb": [
                    ContextParallelInput(
                        split_dim=0, expected_dims=2, split_output=False
                    )
                    for _ in range(2)
                ],
            },
            # proj_out: split input -> all gather -> un-split output.
            "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
        }
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
@functools.wraps(CogVideoXAttnProcessor2_0.__call__)
def __patch_CogVideoXAttnProcessor2_0__call__(
    self: CogVideoXAttnProcessor2_0,
    attn: Attention,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    image_rotary_emb: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """CP-aware replacement for ``CogVideoXAttnProcessor2_0.__call__``.

    Text and video tokens are concatenated (text first) into one joint
    sequence. q/k/v are kept as ``(batch, seq, heads, head_dim)`` (no
    transpose), and attention goes through ``dispatch_attention_fn`` with
    the processor's ``_attention_backend`` / ``_parallel_config`` (``None``
    if absent).

    Returns:
        ``(hidden_states, encoder_hidden_states)``: the video and text parts
        of the joint output.
    """
    n_text = encoder_hidden_states.size(1)
    joint = torch.cat([encoder_hidden_states, hidden_states], dim=1)
    bsz, joint_len, _ = joint.shape

    # NOTE(DefTruth): attention mask is always None in CogVideoX
    if attention_mask is not None:
        attention_mask = attn.prepare_attention_mask(
            attention_mask, joint_len, bsz
        )
        attention_mask = attention_mask.view(
            bsz, attn.heads, -1, attention_mask.shape[-1]
        )

    q_proj = attn.to_q(joint)
    k_proj = attn.to_k(joint)
    v_proj = attn.to_v(joint)
    head_dim = k_proj.shape[-1] // attn.heads

    # NOTE(DefTruth): no transpose -> (batch, seq, heads, head_dim)
    query, key, value = (
        t.view(bsz, -1, attn.heads, head_dim) for t in (q_proj, k_proj, v_proj)
    )

    if attn.norm_q is not None:
        query = attn.norm_q(query)
    if attn.norm_k is not None:
        key = attn.norm_k(key)

    # Apply RoPE, in place, to the video tokens only (after the text prefix);
    # key is rotated only for self-attention.
    if image_rotary_emb is not None:
        from diffusers.models.embeddings import apply_rotary_emb

        rope_targets = [query] if attn.is_cross_attention else [query, key]
        for tensor in rope_targets:
            tensor[:, n_text:] = apply_rotary_emb(
                tensor[:, n_text:],
                image_rotary_emb,
                sequence_dim=1,
            )

    # NOTE(DefTruth): Apply dispatch_attention_fn instead of sdpa directly
    attn_out = dispatch_attention_fn(
        query,
        key,
        value,
        attn_mask=attention_mask,
        dropout_p=0.0,
        is_causal=False,
        backend=getattr(self, "_attention_backend", None),
        parallel_config=getattr(self, "_parallel_config", None),
    )
    attn_out = attn_out.reshape(bsz, -1, attn.heads * head_dim)

    # Output projection (to_out[0]) followed by dropout (to_out[1]).
    attn_out = attn.to_out[1](attn.to_out[0](attn_out))

    text_out, video_out = attn_out.split(
        [n_text, attn_out.size(1) - n_text], dim=1
    )
    return video_out, text_out
|