cache-dit 1.0.3__py3-none-any.whl → 1.0.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cache_dit/__init__.py +37 -19
- cache_dit/_version.py +2 -2
- cache_dit/caching/__init__.py +36 -0
- cache_dit/{cache_factory → caching}/block_adapters/__init__.py +126 -11
- cache_dit/{cache_factory → caching}/block_adapters/block_adapters.py +78 -7
- cache_dit/caching/block_adapters/block_registers.py +118 -0
- cache_dit/caching/cache_adapters/__init__.py +1 -0
- cache_dit/{cache_factory → caching}/cache_adapters/cache_adapter.py +214 -114
- cache_dit/caching/cache_blocks/__init__.py +226 -0
- cache_dit/caching/cache_blocks/pattern_0_1_2.py +26 -0
- cache_dit/caching/cache_blocks/pattern_3_4_5.py +543 -0
- cache_dit/caching/cache_blocks/pattern_base.py +748 -0
- cache_dit/caching/cache_blocks/pattern_utils.py +86 -0
- cache_dit/caching/cache_contexts/__init__.py +28 -0
- cache_dit/caching/cache_contexts/cache_config.py +120 -0
- cache_dit/{cache_factory → caching}/cache_contexts/cache_context.py +18 -94
- cache_dit/{cache_factory → caching}/cache_contexts/cache_manager.py +133 -12
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/__init__.py +25 -3
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/foca.py +1 -1
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/taylorseer.py +81 -9
- cache_dit/caching/cache_contexts/context_manager.py +36 -0
- cache_dit/caching/cache_contexts/prune_config.py +63 -0
- cache_dit/caching/cache_contexts/prune_context.py +155 -0
- cache_dit/caching/cache_contexts/prune_manager.py +167 -0
- cache_dit/{cache_factory → caching}/cache_interface.py +150 -37
- cache_dit/{cache_factory → caching}/cache_types.py +19 -2
- cache_dit/{cache_factory → caching}/forward_pattern.py +14 -14
- cache_dit/{cache_factory → caching}/params_modifier.py +10 -10
- cache_dit/caching/patch_functors/__init__.py +15 -0
- cache_dit/{cache_factory → caching}/patch_functors/functor_chroma.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_dit.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_flux.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_hidream.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_hunyuan_dit.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_qwen_image_controlnet.py +1 -1
- cache_dit/{cache_factory → caching}/utils.py +19 -8
- cache_dit/metrics/__init__.py +11 -0
- cache_dit/parallelism/__init__.py +3 -0
- cache_dit/parallelism/backends/native_diffusers/__init__.py +6 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +164 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/__init__.py +4 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_attention_dispatch.py +304 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_chroma.py +95 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogvideox.py +202 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogview.py +299 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cosisid.py +123 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_dit.py +94 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_flux.py +88 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_hunyuan.py +729 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_ltxvideo.py +264 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_nunchaku.py +407 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_pixart.py +285 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_qwen_image.py +104 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +84 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_wan.py +101 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +117 -0
- cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +49 -0
- cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
- cache_dit/parallelism/backends/native_pytorch/__init__.py +6 -0
- cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +171 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_kandinsky5.py +79 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +65 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +14 -0
- cache_dit/parallelism/parallel_backend.py +26 -0
- cache_dit/parallelism/parallel_config.py +88 -0
- cache_dit/parallelism/parallel_interface.py +77 -0
- cache_dit/quantize/__init__.py +7 -0
- cache_dit/quantize/backends/__init__.py +1 -0
- cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
- cache_dit/quantize/backends/torchao/__init__.py +1 -0
- cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +40 -30
- cache_dit/quantize/quantize_backend.py +0 -0
- cache_dit/quantize/quantize_config.py +0 -0
- cache_dit/quantize/quantize_interface.py +3 -16
- cache_dit/summary.py +593 -0
- cache_dit/utils.py +46 -290
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/METADATA +123 -116
- cache_dit-1.0.14.dist-info/RECORD +102 -0
- cache_dit-1.0.14.dist-info/licenses/LICENSE +203 -0
- cache_dit/cache_factory/__init__.py +0 -28
- cache_dit/cache_factory/block_adapters/block_registers.py +0 -90
- cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
- cache_dit/cache_factory/cache_blocks/__init__.py +0 -76
- cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +0 -16
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +0 -306
- cache_dit/cache_factory/cache_blocks/pattern_base.py +0 -458
- cache_dit/cache_factory/cache_blocks/pattern_utils.py +0 -41
- cache_dit/cache_factory/cache_contexts/__init__.py +0 -15
- cache_dit/cache_factory/patch_functors/__init__.py +0 -15
- cache_dit-1.0.3.dist-info/RECORD +0 -58
- cache_dit-1.0.3.dist-info/licenses/LICENSE +0 -53
- /cache_dit/{cache_factory → caching}/.gitignore +0 -0
- /cache_dit/{cache_factory → caching}/cache_blocks/offload_utils.py +0 -0
- /cache_dit/{cache_factory → caching}/cache_contexts/calibrators/base.py +0 -0
- /cache_dit/{cache_factory → caching}/patch_functors/functor_base.py +0 -0
- /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
- /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/WHEEL +0 -0
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/entry_points.txt +0 -0
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/top_level.txt +0 -0
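
The headline change in this release is a package-level rename: `cache_dit/cache_factory/*` moves to `cache_dit/caching/*` and `cache_dit/custom_ops/*` moves to `cache_dit/kernels/*`, alongside new `parallelism`, `summary`, and quantization backend modules. A minimal migration sketch based only on the renamed paths listed above (module contents are assumed to keep their public names; illustrative, not taken from the project docs):

```python
# Hypothetical compatibility shim for downstream code, assuming only the
# path rename shown in the file list (cache_factory -> caching).
try:
    from cache_dit.caching import cache_interface        # layout in 1.0.14
except ImportError:
    from cache_dit.cache_factory import cache_interface  # layout in 1.0.3
```

The hunks below cover the new context-parallelism planners under `cache_dit/parallelism/backends/native_diffusers/context_parallelism/`.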
cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_pixart.py
@@ -0,0 +1,285 @@
+import torch
+import functools
+from typing import Optional
+import torch.nn.functional as F
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.transformers.pixart_transformer_2d import (
+    PixArtTransformer2DModel,
+)
+from diffusers.models.attention_processor import (
+    Attention,
+    AttnProcessor2_0,
+)  # sdpa
+from diffusers.utils import deprecate
+from diffusers.models.attention_dispatch import dispatch_attention_fn
+
+try:
+    from diffusers.models._modeling_parallel import (
+        ContextParallelInput,
+        ContextParallelOutput,
+        ContextParallelModelPlan,
+    )
+except ImportError:
+    raise ImportError(
+        "Context parallelism requires the 'diffusers>=0.36.dev0'."
+        "Please install latest version of diffusers from source: \n"
+        "pip3 install git+https://github.com/huggingface/diffusers.git"
+    )
+from .cp_plan_registers import (
+    ContextParallelismPlanner,
+    ContextParallelismPlannerRegister,
+)
+
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+@ContextParallelismPlannerRegister.register("PixArt")
+class PixArtContextParallelismPlanner(ContextParallelismPlanner):
+    def apply(
+        self,
+        transformer: Optional[torch.nn.Module | ModelMixin] = None,
+        **kwargs,
+    ) -> ContextParallelModelPlan:
+        assert transformer is not None, "Transformer must be provided."
+        assert isinstance(
+            transformer, PixArtTransformer2DModel
+        ), "Transformer must be an instance of PixArtTransformer2DModel"
+
+        self._cp_planner_preferred_native_diffusers = False
+
+        if (
+            transformer is not None
+            and self._cp_planner_preferred_native_diffusers
+        ):
+            if hasattr(transformer, "_cp_plan"):
+                if transformer._cp_plan is not None:
+                    return transformer._cp_plan
+
+        # Apply monkey patch to fix attention mask preparation at class level
+        Attention.prepare_attention_mask = (
+            __patch_Attention_prepare_attention_mask__
+        )
+        AttnProcessor2_0.__call__ = __patch_AttnProcessor2_0__call__
+        if not hasattr(AttnProcessor2_0, "_parallel_config"):
+            AttnProcessor2_0._parallel_config = None
+        if not hasattr(AttnProcessor2_0, "_attention_backend"):
+            AttnProcessor2_0._attention_backend = None
+
+        # Otherwise, use the custom CP plan defined here, this maybe
+        # a little different from the native diffusers implementation
+        # for some models.
+
+        _cp_plan = {
+            # Pattern of transformer_blocks.0, split_output=False:
+            # un-split input -> split -> to_qkv/...
+            # -> all2all
+            # -> attn (local head, full seqlen)
+            # -> all2all
+            # -> splited output
+            # (only split hidden_states, not encoder_hidden_states)
+            "transformer_blocks.0": {
+                "hidden_states": ContextParallelInput(
+                    split_dim=1, expected_dims=3, split_output=False
+                ),
+            },
+            # Pattern of the all blocks, split_output=False:
+            # un-split input -> split -> to_qkv/...
+            # -> all2all
+            # -> attn (local head, full seqlen)
+            # -> all2all
+            # -> splited output
+            # (only split encoder_hidden_states, not hidden_states.
+            # hidden_states has been automatically split in previous
+            # block by all2all comm op after attn)
+            # The `encoder_hidden_states` will [NOT] be changed after each block forward,
+            # so we need to split it at [ALL] block by the inserted split hook.
+            "transformer_blocks.*": {
+                "encoder_hidden_states": ContextParallelInput(
+                    split_dim=1, expected_dims=3, split_output=False
+                ),
+            },
+            # Then, the final proj_out will gather the splited output.
+            # splited input (previous splited output)
+            # -> all gather
+            # -> un-split output
+            "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+        }
+        return _cp_plan
+
+
+@functools.wraps(Attention.prepare_attention_mask)
+def __patch_Attention_prepare_attention_mask__(
+    self: Attention,
+    attention_mask: torch.Tensor,
+    target_length: int,
+    batch_size: int,
+    out_dim: int = 3,
+    # NOTE(DefTruth): Allow specifying head_size for CP
+    head_size: Optional[int] = None,
+) -> torch.Tensor:
+    if head_size is None:
+        head_size = self.heads
+    if attention_mask is None:
+        return attention_mask
+
+    current_length: int = attention_mask.shape[-1]
+    if current_length != target_length:
+        if attention_mask.device.type == "mps":
+            # HACK: MPS: Does not support padding by greater than dimension of input tensor.
+            # Instead, we can manually construct the padding tensor.
+            padding_shape = (
+                attention_mask.shape[0],
+                attention_mask.shape[1],
+                target_length,
+            )
+            padding = torch.zeros(
+                padding_shape,
+                dtype=attention_mask.dtype,
+                device=attention_mask.device,
+            )
+            attention_mask = torch.cat([attention_mask, padding], dim=2)
+        else:
+            # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
+            # we want to instead pad by (0, remaining_length), where remaining_length is:
+            # remaining_length: int = target_length - current_length
+            # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
+            attention_mask = F.pad(
+                attention_mask, (0, target_length), value=0.0
+            )
+
+    if out_dim == 3:
+        if attention_mask.shape[0] < batch_size * head_size:
+            attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
+    elif out_dim == 4:
+        attention_mask = attention_mask.unsqueeze(1)
+        attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
+
+    return attention_mask
+
+
+@functools.wraps(AttnProcessor2_0.__call__)
+def __patch_AttnProcessor2_0__call__(
+    self: AttnProcessor2_0,
+    attn: Attention,
+    hidden_states: torch.Tensor,
+    encoder_hidden_states: Optional[torch.Tensor] = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    temb: Optional[torch.Tensor] = None,
+    *args,
+    **kwargs,
+) -> torch.Tensor:
+    if len(args) > 0 or kwargs.get("scale", None) is not None:
+        deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+        deprecate("scale", "1.0.0", deprecation_message)
+
+    residual = hidden_states
+    if attn.spatial_norm is not None:
+        hidden_states = attn.spatial_norm(hidden_states, temb)
+
+    input_ndim = hidden_states.ndim
+
+    if input_ndim == 4:
+        batch_size, channel, height, width = hidden_states.shape
+        hidden_states = hidden_states.view(
+            batch_size, channel, height * width
+        ).transpose(1, 2)
+
+    batch_size, sequence_length, _ = (
+        hidden_states.shape
+        if encoder_hidden_states is None
+        else encoder_hidden_states.shape
+    )
+
+    if attention_mask is not None:
+        if self._parallel_config is None:
+            attention_mask = attn.prepare_attention_mask(
+                attention_mask, sequence_length, batch_size
+            )
+            attention_mask = attention_mask.view(
+                batch_size, attn.heads, -1, attention_mask.shape[-1]
+            )
+        else:
+            # NOTE(DefTruth): Fix attention mask preparation for context parallelism.
+            # Please note that in context parallelism, the sequence_length is the local
+            # sequence length on each rank. So we need to adjust the target_length
+            # accordingly. The head_size is also adjusted based on the world size
+            # in order to make sdpa work correctly, otherwise, the sdpa op will raise
+            # error due to the mismatch between attention_mask shape and expected shape.
+            cp_config = getattr(
+                self._parallel_config, "context_parallel_config", None
+            )
+            if cp_config is not None and cp_config._world_size > 1:
+                head_size = attn.heads // cp_config._world_size
+                attention_mask = attn.prepare_attention_mask(
+                    attention_mask,
+                    sequence_length * cp_config._world_size,
+                    batch_size,
+                    3,
+                    head_size,
+                )
+                attention_mask = attention_mask.view(
+                    batch_size, head_size, -1, attention_mask.shape[-1]
+                )
+
+    if attn.group_norm is not None:
+        hidden_states = attn.group_norm(
+            hidden_states.transpose(1, 2)
+        ).transpose(1, 2)
+
+    query = attn.to_q(hidden_states)
+
+    if encoder_hidden_states is None:
+        encoder_hidden_states = hidden_states
+    elif attn.norm_cross:
+        encoder_hidden_states = attn.norm_encoder_hidden_states(
+            encoder_hidden_states
+        )
+
+    key = attn.to_k(encoder_hidden_states)
+    value = attn.to_v(encoder_hidden_states)
+
+    inner_dim = key.shape[-1]
+    head_dim = inner_dim // attn.heads
+
+    # NOTE(DefTruth): no transpose now
+    query = query.view(batch_size, -1, attn.heads, head_dim)
+    key = key.view(batch_size, -1, attn.heads, head_dim)
+    value = value.view(batch_size, -1, attn.heads, head_dim)
+
+    if attn.norm_q is not None:
+        query = attn.norm_q(query)
+    if attn.norm_k is not None:
+        key = attn.norm_k(key)
+
+    # NOTE(DefTruth): Use the dispatch_attention_fn to support different backends
+    hidden_states = dispatch_attention_fn(
+        query,
+        key,
+        value,
+        attn_mask=attention_mask,
+        dropout_p=0.0,
+        is_causal=False,
+        backend=getattr(self, "_attention_backend", None),
+        parallel_config=getattr(self, "_parallel_config", None),
+    )
+    hidden_states = hidden_states.flatten(2, 3)
+    hidden_states = hidden_states.to(query.dtype)
+
+    # linear proj
+    hidden_states = attn.to_out[0](hidden_states)
+    # dropout
+    hidden_states = attn.to_out[1](hidden_states)
+
+    if input_ndim == 4:
+        hidden_states = hidden_states.transpose(-1, -2).reshape(
+            batch_size, channel, height, width
+        )
+
+    if attn.residual_connection:
+        hidden_states = hidden_states + residual
+
+    hidden_states = hidden_states / attn.rescale_output_factor
+
+    return hidden_states
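
The patched `prepare_attention_mask` above is what makes PixArt's masked cross-attention work under context parallelism: each rank sees the local sequence length but, after the all-to-all, runs SDPA with `attn.heads // world_size` local heads over the full sequence. A toy shape check of that arithmetic (hypothetical sizes, single process, plain SDPA in place of `dispatch_attention_fn`):

```python
import torch
import torch.nn.functional as F

# Hypothetical sizes for illustration only.
batch, heads, world_size, seq_local, head_dim = 2, 16, 2, 128, 64
heads_local = heads // world_size   # local heads per rank after all-to-all
seq_full = seq_local * world_size   # full sequence length seen by attention

q = torch.randn(batch, heads_local, seq_full, head_dim)
k = torch.randn(batch, heads_local, seq_full, head_dim)
v = torch.randn(batch, heads_local, seq_full, head_dim)

# An additive mask repeated by heads_local (not the global head count) broadcasts
# cleanly against the (seq_full, seq_full) attention scores; a mask built with
# the global head count would fail SDPA's shape check.
attn_mask = torch.zeros(batch, heads_local, 1, seq_full)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
print(out.shape)  # torch.Size([2, 8, 256, 64])
```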
cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_qwen_image.py
@@ -0,0 +1,104 @@
+import torch
+from typing import Optional
+from diffusers.models.modeling_utils import ModelMixin
+
+try:
+    from diffusers.models._modeling_parallel import (
+        ContextParallelInput,
+        ContextParallelOutput,
+        ContextParallelModelPlan,
+    )
+except ImportError:
+    raise ImportError(
+        "Context parallelism requires the 'diffusers>=0.36.dev0'."
+        "Please install latest version of diffusers from source: \n"
+        "pip3 install git+https://github.com/huggingface/diffusers.git"
+    )
+from .cp_plan_registers import (
+    ContextParallelismPlanner,
+    ContextParallelismPlannerRegister,
+)
+
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+@ContextParallelismPlannerRegister.register("QwenImage")
+class QwenImageContextParallelismPlanner(ContextParallelismPlanner):
+    def apply(
+        self,
+        transformer: Optional[torch.nn.Module | ModelMixin] = None,
+        **kwargs,
+    ) -> ContextParallelModelPlan:
+
+        # NOTE: Set it as False to use custom CP plan defined here.
+        self._cp_planner_preferred_native_diffusers = False
+
+        if (
+            transformer is not None
+            and self._cp_planner_preferred_native_diffusers
+        ):
+            from diffusers import QwenImageTransformer2DModel
+
+            assert isinstance(
+                transformer, QwenImageTransformer2DModel
+            ), "Transformer must be an instance of QwenImageTransformer2DModel"
+            if hasattr(transformer, "_cp_plan"):
+                if transformer._cp_plan is not None:
+                    return transformer._cp_plan
+
+        # Otherwise, use the custom CP plan defined here, this maybe
+        # a little different from the native diffusers implementation
+        # for some models.
+        _cp_plan = {
+            # Here is a Transformer level CP plan for Flux, which will
+            # only apply the only 1 split hook (pre_forward) on the forward
+            # of Transformer, and gather the output after Transformer forward.
+            # Pattern of transformer forward, split_output=False:
+            # un-split input -> splited input (inside transformer)
+            # Pattern of the transformer_blocks, single_transformer_blocks:
+            # splited input (previous splited output) -> to_qkv/...
+            # -> all2all
+            # -> attn (local head, full seqlen)
+            # -> all2all
+            # -> splited output
+            # The `hidden_states` and `encoder_hidden_states` will still keep
+            # itself splited after block forward (namely, automatic split by
+            # the all2all comm op after attn) for the all blocks.
+            "": {
+                "hidden_states": ContextParallelInput(
+                    split_dim=1, expected_dims=3, split_output=False
+                ),
+                # NOTE: Due to the joint attention implementation of
+                # QwenImageTransformerBlock, we must split the
+                # encoder_hidden_states as well.
+                "encoder_hidden_states": ContextParallelInput(
+                    split_dim=1, expected_dims=3, split_output=False
+                ),
+                # NOTE: But encoder_hidden_states_mask seems never used in
+                # QwenImageTransformerBlock, so we do not split it here.
+                # "encoder_hidden_states_mask": ContextParallelInput(
+                #     split_dim=1, expected_dims=2, split_output=False
+                # ),
+            },
+            # Pattern of pos_embed, split_output=True (split output rather than input):
+            # un-split input
+            # -> keep input un-split
+            # -> rope
+            # -> splited output
+            "pos_embed": {
+                0: ContextParallelInput(
+                    split_dim=0, expected_dims=2, split_output=True
+                ),
+                1: ContextParallelInput(
+                    split_dim=0, expected_dims=2, split_output=True
+                ),
+            },
+            # Then, the final proj_out will gather the splited output.
+            # splited input (previous splited output)
+            # -> all gather
+            # -> un-split output
+            "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+        }
+        return _cp_plan
cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py
@@ -0,0 +1,84 @@
+import torch
+import logging
+from abc import abstractmethod
+from typing import Optional
+from diffusers.models.modeling_utils import ModelMixin
+
+try:
+    from diffusers.models._modeling_parallel import (
+        ContextParallelModelPlan,
+    )
+except ImportError:
+    raise ImportError(
+        "Context parallelism requires the 'diffusers>=0.36.dev0'."
+        "Please install latest version of diffusers from source: \n"
+        "pip3 install git+https://github.com/huggingface/diffusers.git"
+    )
+
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+__all__ = [
+    "ContextParallelismPlanner",
+    "ContextParallelismPlannerRegister",
+]
+
+
+class ContextParallelismPlanner:
+    # Prefer native diffusers implementation if available
+    _cp_planner_preferred_native_diffusers: bool = True
+
+    @abstractmethod
+    def apply(
+        self,
+        # NOTE: Keep this kwarg for future extensions
+        transformer: Optional[torch.nn.Module | ModelMixin] = None,
+        **kwargs,
+    ) -> ContextParallelModelPlan:
+        # NOTE: This method should only return the CP plan dictionary.
+        raise NotImplementedError(
+            "apply method must be implemented by subclasses"
+        )
+
+
+class ContextParallelismPlannerRegister:
+    _cp_planner_registry: dict[str, ContextParallelismPlanner] = {}
+
+    @classmethod
+    def register(cls, name: str):
+        def decorator(planner_cls: type[ContextParallelismPlanner]):
+            assert (
+                name not in cls._cp_planner_registry
+            ), f"ContextParallelismPlanner with name {name} is already registered."
+            if logger.isEnabledFor(logging.DEBUG):
+                logger.debug(f"Registering ContextParallelismPlanner: {name}")
+            cls._cp_planner_registry[name] = planner_cls
+            return planner_cls
+
+        return decorator
+
+    @classmethod
+    def get_planner(
+        cls, transformer: str | torch.nn.Module | ModelMixin
+    ) -> type[ContextParallelismPlanner]:
+        if isinstance(transformer, (torch.nn.Module, ModelMixin)):
+            name = transformer.__class__.__name__
+        else:
+            name = transformer
+        planner_cls = None
+        for planner_name in cls._cp_planner_registry:
+            if name.startswith(planner_name):
+                planner_cls = cls._cp_planner_registry.get(planner_name)
+                break
+        if planner_cls is None:
+            raise ValueError(f"No planner registered under name: {name}")
+        return planner_cls
+
+    @classmethod
+    def supported_planners(
+        cls,
+    ) -> tuple[int, list[str]]:
+        val_planners = cls._cp_planner_registry.keys()
+        return len(val_planners), [p for p in val_planners]
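
For orientation, a minimal sketch of how this registry is meant to be extended; `MyDiT` and the planner class below are hypothetical names, the import path follows the file list at the top of this diff, and `diffusers>=0.36.dev0` is assumed to be installed:

```python
from cache_dit.parallelism.backends.native_diffusers.context_parallelism.cp_plan_registers import (
    ContextParallelismPlanner,
    ContextParallelismPlannerRegister,
)


@ContextParallelismPlannerRegister.register("MyDiT")  # hypothetical model prefix
class MyDiTContextParallelismPlanner(ContextParallelismPlanner):
    def apply(self, transformer=None, **kwargs):
        # Illustration only: an empty plan splits/gathers nothing.
        return {}


# Lookup matches by class-name prefix, so a transformer class named
# "MyDiTTransformer2DModel" resolves to the planner registered as "MyDiT".
planner_cls = ContextParallelismPlannerRegister.get_planner("MyDiTTransformer2DModel")
print(planner_cls().apply())                                   # {}
print(ContextParallelismPlannerRegister.supported_planners())  # (count, [names, ...])
```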
cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_wan.py
@@ -0,0 +1,101 @@
+import torch
+from typing import Optional
+from diffusers.models.modeling_utils import ModelMixin
+
+try:
+    from diffusers.models._modeling_parallel import (
+        ContextParallelInput,
+        ContextParallelOutput,
+        ContextParallelModelPlan,
+    )
+except ImportError:
+    raise ImportError(
+        "Context parallelism requires the 'diffusers>=0.36.dev0'."
+        "Please install latest version of diffusers from source: \n"
+        "pip3 install git+https://github.com/huggingface/diffusers.git"
+    )
+from .cp_plan_registers import (
+    ContextParallelismPlanner,
+    ContextParallelismPlannerRegister,
+)
+
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+# TODO: Add WanVACETransformer3DModel context parallelism planner.
+# NOTE: Maybe use full name to avoid name conflict between
+# WanTransformer3DModel and WanVACETransformer3DModel?
+@ContextParallelismPlannerRegister.register("WanTransformer3D")
+class WanContextParallelismPlanner(ContextParallelismPlanner):
+    def apply(
+        self,
+        transformer: Optional[torch.nn.Module | ModelMixin] = None,
+        **kwargs,
+    ) -> ContextParallelModelPlan:
+        if (
+            transformer is not None
+            and self._cp_planner_preferred_native_diffusers
+        ):
+            from diffusers import WanTransformer3DModel
+
+            assert isinstance(
+                transformer, WanTransformer3DModel
+            ), "Transformer must be an instance of WanTransformer3DModel"
+            if hasattr(transformer, "_cp_plan"):
+                if transformer._cp_plan is not None:
+                    return transformer._cp_plan
+
+        # Otherwise, use the custom CP plan defined here, this maybe
+        # a little different from the native diffusers implementation
+        # for some models.
+        _cp_plan = {
+            # Pattern of rope, split_output=True (split output rather than input):
+            # un-split input
+            # -> keep input un-split
+            # -> rope
+            # -> splited output
+            "rope": {
+                0: ContextParallelInput(
+                    split_dim=1, expected_dims=4, split_output=True
+                ),
+                1: ContextParallelInput(
+                    split_dim=1, expected_dims=4, split_output=True
+                ),
+            },
+            # Pattern of blocks.0, split_output=False:
+            # un-split input -> split -> to_qkv/...
+            # -> all2all
+            # -> attn (local head, full seqlen)
+            # -> all2all
+            # -> splited output
+            # (only split hidden_states, not encoder_hidden_states)
+            "blocks.0": {
+                "hidden_states": ContextParallelInput(
+                    split_dim=1, expected_dims=3, split_output=False
+                ),
+            },
+            # Pattern of the all blocks, split_output=False:
+            # un-split input -> split -> to_qkv/...
+            # -> all2all
+            # -> attn (local head, full seqlen)
+            # -> all2all
+            # -> splited output
+            # (only split encoder_hidden_states, not hidden_states.
+            # hidden_states has been automatically split in previous
+            # block by all2all comm op after attn)
+            # The `encoder_hidden_states` will [NOT] be changed after each block forward,
+            # so we need to split it at [ALL] block by the inserted split hook.
+            "blocks.*": {
+                "encoder_hidden_states": ContextParallelInput(
+                    split_dim=1, expected_dims=3, split_output=False
+                ),
+            },
+            # Then, the final proj_out will gather the splited output.
+            # splited input (previous splited output)
+            # -> all gather
+            # -> un-split output
+            "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+        }
+        return _cp_plan
cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py
@@ -0,0 +1,117 @@
+# Docstring references: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/_modeling_parallel.py#L185
+# A dictionary where keys denote the input to be split across context parallel region, and the
+# value denotes the sharding configuration.
+# If the key is a string, it denotes the name of the parameter in the forward function.
+# If the key is an integer, split_output must be set to True, and it denotes the index of the output
+# to be split across context parallel region.
+# ContextParallelInputType = Dict[
+#     Union[str, int], Union[ContextParallelInput, List[ContextParallelInput], Tuple[ContextParallelInput, ...]]
+# ]
+
+# A dictionary where keys denote the output to be gathered across context parallel region, and the
+# value denotes the gathering configuration.
+# ContextParallelOutputType = Union[
+#     ContextParallelOutput, List[ContextParallelOutput], Tuple[ContextParallelOutput, ...]
+# ]
+
+# A dictionary where keys denote the module id, and the value denotes how the inputs/outputs of
+# the module should be split/gathered across context parallel region.
+# ContextParallelModelPlan = Dict[str, Union[ContextParallelInputType, ContextParallelOutputType]]
+
+# Example of a ContextParallelModelPlan (QwenImageTransformer2DModel):
+#
+# Each model should define a _cp_plan attribute that contains information on how to shard/gather
+# tensors at different stages of the forward:
+#
+# ```python
+# _cp_plan = {
+#     "": {
+#         "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+#         "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+#         "encoder_hidden_states_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
+#     },
+#     "pos_embed": {
+#         0: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
+#         1: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
+#     },
+#     "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+# }
+# ```
+#
+# The dictionary is a set of module names mapped to their respective CP plan. The inputs/outputs of layers will be
+# split/gathered according to this at the respective module level. Here, the following happens:
+# - "":
+#     we specify that we want to split the various inputs across the sequence dim in the pre-forward hook (i.e. before
+#     the actual forward logic of the QwenImageTransformer2DModel is run, we will split the inputs)
+# - "pos_embed":
+#     we specify that we want to split the outputs of the RoPE layer. Since there are two outputs (imag & text freqs),
+#     we can individually specify how they should be split
+# - "proj_out":
+#     before returning to the user, we gather the entire sequence on each rank in the post-forward hook (after the linear
+#     layer forward has run).
+#
+# ContextParallelInput:
+#     specifies how to split the input tensor in the pre-forward or post-forward hook of the layer it is attached to
+#
+# ContextParallelOutput:
+#     specifies how to gather the input tensor in the post-forward hook in the layer it is attached to
+
+from .cp_plan_registers import (
+    ContextParallelismPlanner,
+    ContextParallelismPlannerRegister,
+)
+from .cp_plan_flux import FluxContextParallelismPlanner
+from .cp_plan_qwen_image import QwenImageContextParallelismPlanner
+from .cp_plan_wan import WanContextParallelismPlanner
+from .cp_plan_ltxvideo import LTXVideoContextParallelismPlanner
+from .cp_plan_hunyuan import HunyuanImageContextParallelismPlanner
+from .cp_plan_hunyuan import HunyuanVideoContextParallelismPlanner
+from .cp_plan_cogvideox import CogVideoXContextParallelismPlanner
+from .cp_plan_cogview import CogView3PlusContextParallelismPlanner
+from .cp_plan_cogview import CogView4ContextParallelismPlanner
+from .cp_plan_cosisid import CosisIDContextParallelismPlanner
+from .cp_plan_chroma import ChromaContextParallelismPlanner
+from .cp_plan_pixart import PixArtContextParallelismPlanner
+from .cp_plan_dit import DiTContextParallelismPlanner
+
+try:
+    import nunchaku  # noqa: F401
+
+    _nunchaku_available = True
+except ImportError:
+    _nunchaku_available = False
+
+if _nunchaku_available:
+    from .cp_plan_nunchaku import (  # noqa: F401
+        NunchakuFluxContextParallelismPlanner,
+    )
+    from .cp_plan_nunchaku import (  # noqa: F401
+        NunchakuQwenImageContextParallelismPlanner,
+    )
+
+
+__all__ = [
+    "ContextParallelismPlanner",
+    "ContextParallelismPlannerRegister",
+    "FluxContextParallelismPlanner",
+    "QwenImageContextParallelismPlanner",
+    "WanContextParallelismPlanner",
+    "LTXVideoContextParallelismPlanner",
+    "HunyuanImageContextParallelismPlanner",
+    "HunyuanVideoContextParallelismPlanner",
+    "CogVideoXContextParallelismPlanner",
+    "CogView3PlusContextParallelismPlanner",
+    "CogView4ContextParallelismPlanner",
+    "CosisIDContextParallelismPlanner",
+    "ChromaContextParallelismPlanner",
+    "PixArtContextParallelismPlanner",
+    "DiTContextParallelismPlanner",
+]
+
+if _nunchaku_available:
+    __all__.extend(
+        [
+            "NunchakuFluxContextParallelismPlanner",
+            "NunchakuQwenImageContextParallelismPlanner",
+        ]
+    )
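
The comment block at the top of `cp_planners.py` describes `_cp_plan` entries in terms of sequence-dimension splits and gathers. A self-contained toy of what `ContextParallelInput(split_dim=1, expected_dims=3)` and `ContextParallelOutput(gather_dim=1, expected_dims=3)` express, using hypothetical sizes and plain `torch.chunk`/`torch.cat` in place of the real pre-/post-forward hooks and collectives:

```python
import torch

world_size = 4
hidden_states = torch.randn(2, 1024, 3072)   # (batch, seq, dim), hypothetical sizes

# Pre-forward split hook: shard the sequence dimension across ranks.
shards = torch.chunk(hidden_states, world_size, dim=1)
local = shards[0]                             # the slice a single rank would see
print(local.shape)                            # torch.Size([2, 256, 3072])

# Post-forward gather hook on proj_out: the all-gather restores the full sequence.
gathered = torch.cat(shards, dim=1)
assert torch.equal(gathered, hidden_states)
```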