cache_dit-1.0.3-py3-none-any.whl → cache_dit-1.0.14-py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- cache_dit/__init__.py +37 -19
- cache_dit/_version.py +2 -2
- cache_dit/caching/__init__.py +36 -0
- cache_dit/{cache_factory → caching}/block_adapters/__init__.py +126 -11
- cache_dit/{cache_factory → caching}/block_adapters/block_adapters.py +78 -7
- cache_dit/caching/block_adapters/block_registers.py +118 -0
- cache_dit/caching/cache_adapters/__init__.py +1 -0
- cache_dit/{cache_factory → caching}/cache_adapters/cache_adapter.py +214 -114
- cache_dit/caching/cache_blocks/__init__.py +226 -0
- cache_dit/caching/cache_blocks/pattern_0_1_2.py +26 -0
- cache_dit/caching/cache_blocks/pattern_3_4_5.py +543 -0
- cache_dit/caching/cache_blocks/pattern_base.py +748 -0
- cache_dit/caching/cache_blocks/pattern_utils.py +86 -0
- cache_dit/caching/cache_contexts/__init__.py +28 -0
- cache_dit/caching/cache_contexts/cache_config.py +120 -0
- cache_dit/{cache_factory → caching}/cache_contexts/cache_context.py +18 -94
- cache_dit/{cache_factory → caching}/cache_contexts/cache_manager.py +133 -12
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/__init__.py +25 -3
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/foca.py +1 -1
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/taylorseer.py +81 -9
- cache_dit/caching/cache_contexts/context_manager.py +36 -0
- cache_dit/caching/cache_contexts/prune_config.py +63 -0
- cache_dit/caching/cache_contexts/prune_context.py +155 -0
- cache_dit/caching/cache_contexts/prune_manager.py +167 -0
- cache_dit/{cache_factory → caching}/cache_interface.py +150 -37
- cache_dit/{cache_factory → caching}/cache_types.py +19 -2
- cache_dit/{cache_factory → caching}/forward_pattern.py +14 -14
- cache_dit/{cache_factory → caching}/params_modifier.py +10 -10
- cache_dit/caching/patch_functors/__init__.py +15 -0
- cache_dit/{cache_factory → caching}/patch_functors/functor_chroma.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_dit.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_flux.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_hidream.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_hunyuan_dit.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_qwen_image_controlnet.py +1 -1
- cache_dit/{cache_factory → caching}/utils.py +19 -8
- cache_dit/metrics/__init__.py +11 -0
- cache_dit/parallelism/__init__.py +3 -0
- cache_dit/parallelism/backends/native_diffusers/__init__.py +6 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +164 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/__init__.py +4 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_attention_dispatch.py +304 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_chroma.py +95 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogvideox.py +202 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogview.py +299 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cosisid.py +123 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_dit.py +94 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_flux.py +88 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_hunyuan.py +729 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_ltxvideo.py +264 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_nunchaku.py +407 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_pixart.py +285 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_qwen_image.py +104 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +84 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_wan.py +101 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +117 -0
- cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +49 -0
- cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
- cache_dit/parallelism/backends/native_pytorch/__init__.py +6 -0
- cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +171 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_kandinsky5.py +79 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +65 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +14 -0
- cache_dit/parallelism/parallel_backend.py +26 -0
- cache_dit/parallelism/parallel_config.py +88 -0
- cache_dit/parallelism/parallel_interface.py +77 -0
- cache_dit/quantize/__init__.py +7 -0
- cache_dit/quantize/backends/__init__.py +1 -0
- cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
- cache_dit/quantize/backends/torchao/__init__.py +1 -0
- cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +40 -30
- cache_dit/quantize/quantize_backend.py +0 -0
- cache_dit/quantize/quantize_config.py +0 -0
- cache_dit/quantize/quantize_interface.py +3 -16
- cache_dit/summary.py +593 -0
- cache_dit/utils.py +46 -290
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/METADATA +123 -116
- cache_dit-1.0.14.dist-info/RECORD +102 -0
- cache_dit-1.0.14.dist-info/licenses/LICENSE +203 -0
- cache_dit/cache_factory/__init__.py +0 -28
- cache_dit/cache_factory/block_adapters/block_registers.py +0 -90
- cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
- cache_dit/cache_factory/cache_blocks/__init__.py +0 -76
- cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +0 -16
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +0 -306
- cache_dit/cache_factory/cache_blocks/pattern_base.py +0 -458
- cache_dit/cache_factory/cache_blocks/pattern_utils.py +0 -41
- cache_dit/cache_factory/cache_contexts/__init__.py +0 -15
- cache_dit/cache_factory/patch_functors/__init__.py +0 -15
- cache_dit-1.0.3.dist-info/RECORD +0 -58
- cache_dit-1.0.3.dist-info/licenses/LICENSE +0 -53
- /cache_dit/{cache_factory → caching}/.gitignore +0 -0
- /cache_dit/{cache_factory → caching}/cache_blocks/offload_utils.py +0 -0
- /cache_dit/{cache_factory → caching}/cache_contexts/calibrators/base.py +0 -0
- /cache_dit/{cache_factory → caching}/patch_functors/functor_base.py +0 -0
- /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
- /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/WHEEL +0 -0
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/entry_points.txt +0 -0
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/top_level.txt +0 -0
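
The bulk of the new code in 1.0.14 is the cache_dit/parallelism subpackage listed above; its context-parallelism planners (shown in the hunks below) are registered per model family and return a diffusers ContextParallelModelPlan. The sketch below only mirrors that registration pattern as it is visible in those hunks: the "MyDiT" key and MyDiTContextParallelismPlanner class are hypothetical, and the absolute import path is an assumption based on the file layout above.

    # Sketch only: mirrors the planner-registration pattern in the hunks below.
    # "MyDiT" and MyDiTContextParallelismPlanner are hypothetical names.
    from typing import Optional

    import torch
    from diffusers.models.modeling_utils import ModelMixin
    from diffusers.models._modeling_parallel import (
        ContextParallelInput,
        ContextParallelOutput,
        ContextParallelModelPlan,
    )

    # Import path assumed from the package layout listed above.
    from cache_dit.parallelism.backends.native_diffusers.context_parallelism.cp_plan_registers import (  # noqa: E501
        ContextParallelismPlanner,
        ContextParallelismPlannerRegister,
    )


    @ContextParallelismPlannerRegister.register("MyDiT")  # hypothetical key
    class MyDiTContextParallelismPlanner(ContextParallelismPlanner):
        def apply(
            self,
            transformer: Optional[torch.nn.Module | ModelMixin] = None,
            **kwargs,
        ) -> ContextParallelModelPlan:
            # Shard the sequence dim at the first block, gather it back at proj_out.
            return {
                "transformer_blocks.0": {
                    "hidden_states": ContextParallelInput(
                        split_dim=1, expected_dims=3, split_output=False
                    ),
                },
                "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
            }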
cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogvideox.py
@@ -0,0 +1,202 @@
+import torch
+import functools
+from typing import Optional, Tuple
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.transformers.cogvideox_transformer_3d import (
+    CogVideoXAttnProcessor2_0,
+    CogVideoXTransformer3DModel,
+)
+from diffusers.models.attention_processor import Attention
+from diffusers.models.attention_dispatch import dispatch_attention_fn
+
+try:
+    from diffusers.models._modeling_parallel import (
+        ContextParallelInput,
+        ContextParallelOutput,
+        ContextParallelModelPlan,
+    )
+except ImportError:
+    raise ImportError(
+        "Context parallelism requires the 'diffusers>=0.36.dev0'."
+        "Please install latest version of diffusers from source: \n"
+        "pip3 install git+https://github.com/huggingface/diffusers.git"
+    )
+from .cp_plan_registers import (
+    ContextParallelismPlanner,
+    ContextParallelismPlannerRegister,
+)
+
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+@ContextParallelismPlannerRegister.register("CogVideoX")
+class CogVideoXContextParallelismPlanner(ContextParallelismPlanner):
+    def apply(
+        self,
+        transformer: Optional[torch.nn.Module | ModelMixin] = None,
+        **kwargs,
+    ) -> ContextParallelModelPlan:
+
+        # NOTE: Diffusers native CP plan still not supported
+        # for CogVideoX now.
+        self._cp_planner_preferred_native_diffusers = False
+
+        if (
+            transformer is not None
+            and self._cp_planner_preferred_native_diffusers
+        ):
+            assert isinstance(
+                transformer, CogVideoXTransformer3DModel
+            ), "Transformer must be an instance of CogVideoXTransformer3DModel"
+            if hasattr(transformer, "_cp_plan"):
+                if transformer._cp_plan is not None:
+                    return transformer._cp_plan
+
+        CogVideoXAttnProcessor2_0.__call__ = (
+            __patch_CogVideoXAttnProcessor2_0__call__
+        )
+        # Also need to patch the parallel config and attention backend
+        if not hasattr(CogVideoXAttnProcessor2_0, "_parallel_config"):
+            CogVideoXAttnProcessor2_0._parallel_config = None
+        if not hasattr(CogVideoXAttnProcessor2_0, "_attention_backend"):
+            CogVideoXAttnProcessor2_0._attention_backend = None
+
+        # Otherwise, use the custom CP plan defined here, this maybe
+        # a little different from the native diffusers implementation
+        # for some models.
+        _cp_plan = {
+            # Pattern of transformer_blocks.0, split_output=False:
+            # un-split input -> split -> to_qkv/...
+            # -> all2all
+            # -> attn (local head, full seqlen)
+            # -> all2all
+            # -> splited output
+            # Pattern of the rest transformer_blocks, split_output=False:
+            # splited input (previous splited output) -> to_qkv/...
+            # -> all2all
+            # -> attn (local head, full seqlen)
+            # -> all2all
+            # -> splited output
+            # The `encoder_hidden_states` will be changed after each block forward,
+            # so we need to split it at the first block, and keep it splited (namely,
+            # automatically split by the all2all op after attn) for the rest blocks.
+            # The `out` tensor of local attn will be splited into `hidden_states` and
+            # `encoder_hidden_states` after each block forward, thus both of them
+            # will be automatically splited by all2all comm op after local attn.
+            "transformer_blocks.0": {
+                "hidden_states": ContextParallelInput(
+                    split_dim=1, expected_dims=3, split_output=False
+                ),
+                "encoder_hidden_states": ContextParallelInput(
+                    split_dim=1, expected_dims=3, split_output=False
+                ),
+            },
+            # Pattern of the image_rotary_emb, split at every block, because the it
+            # is not automatically splited by all2all comm op and keep un-splited
+            # while the block forward finished:
+            # un-split input -> split output
+            # -> after block forward
+            # -> un-split input
+            # un-split input -> split output
+            # ...
+            "transformer_blocks.*": {
+                "image_rotary_emb": [
+                    ContextParallelInput(
+                        split_dim=0, expected_dims=2, split_output=False
+                    ),
+                    ContextParallelInput(
+                        split_dim=0, expected_dims=2, split_output=False
+                    ),
+                ],
+            },
+            # transformer forward while using CP, since it is not splited here.
+            # Then, the final proj_out will gather the splited output.
+            # splited input (previous splited output)
+            # -> all gather
+            # -> un-split output
+            "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+        }
+        return _cp_plan
+
+
+@functools.wraps(CogVideoXAttnProcessor2_0.__call__)
+def __patch_CogVideoXAttnProcessor2_0__call__(
+    self: CogVideoXAttnProcessor2_0,
+    attn: Attention,
+    hidden_states: torch.Tensor,
+    encoder_hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    image_rotary_emb: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    text_seq_length = encoder_hidden_states.size(1)
+
+    hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+    batch_size, sequence_length, _ = hidden_states.shape
+
+    # NOTE(DefTruth): attention mask is always None in CogVideoX
+    if attention_mask is not None:
+        attention_mask = attn.prepare_attention_mask(
+            attention_mask, sequence_length, batch_size
+        )
+        attention_mask = attention_mask.view(
+            batch_size, attn.heads, -1, attention_mask.shape[-1]
+        )
+
+    query = attn.to_q(hidden_states)
+    key = attn.to_k(hidden_states)
+    value = attn.to_v(hidden_states)
+
+    inner_dim = key.shape[-1]
+    head_dim = inner_dim // attn.heads
+
+    # NOTE(DefTruth): no transpose
+    query = query.view(batch_size, -1, attn.heads, head_dim)
+    key = key.view(batch_size, -1, attn.heads, head_dim)
+    value = value.view(batch_size, -1, attn.heads, head_dim)
+
+    if attn.norm_q is not None:
+        query = attn.norm_q(query)
+    if attn.norm_k is not None:
+        key = attn.norm_k(key)
+
+    # Apply RoPE if needed
+    if image_rotary_emb is not None:
+        from diffusers.models.embeddings import apply_rotary_emb
+
+        query[:, text_seq_length:] = apply_rotary_emb(
+            query[:, text_seq_length:],
+            image_rotary_emb,
+            sequence_dim=1,
+        )
+        if not attn.is_cross_attention:
+            key[:, text_seq_length:] = apply_rotary_emb(
+                key[:, text_seq_length:],
+                image_rotary_emb,
+                sequence_dim=1,
+            )
+
+    # NOTE(DefTruth): Apply dispatch_attention_fn instead of sdpa directly
+    hidden_states = dispatch_attention_fn(
+        query,
+        key,
+        value,
+        attn_mask=attention_mask,
+        dropout_p=0.0,
+        is_causal=False,
+        backend=getattr(self, "_attention_backend", None),
+        parallel_config=getattr(self, "_parallel_config", None),
+    )
+    hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim)
+
+    # linear proj
+    hidden_states = attn.to_out[0](hidden_states)
+    # dropout
+    hidden_states = attn.to_out[1](hidden_states)
+
+    encoder_hidden_states, hidden_states = hidden_states.split(
+        [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
+    )
+    return hidden_states, encoder_hidden_states
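
The CP plan above relies on ContextParallelInput(split_dim=1, ...) to shard the joint text+video sequence across ranks, with the all-to-all exchanges inside dispatch_attention_fn giving each rank the full sequence for its local heads. As a rough, standalone illustration (not the diffusers hook itself), a dim-1 split and the matching dim-1 gather of a [B, S, D] tensor amount to the following; the shapes, world size, and rank are invented for the example.

    import torch

    # Assumed values, for illustration only.
    world_size, rank = 4, 1
    hidden_states = torch.randn(2, 1024, 3072)  # [B, S, D]

    # ContextParallelInput(split_dim=1, ...): each rank keeps one sequence shard.
    local = hidden_states.chunk(world_size, dim=1)[rank]  # [2, 256, 3072]

    # ContextParallelOutput(gather_dim=1, ...): concatenating the shards along
    # dim=1 restores the full sequence.
    restored = torch.cat(hidden_states.chunk(world_size, dim=1), dim=1)
    assert torch.equal(restored, hidden_states)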
cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogview.py
@@ -0,0 +1,299 @@
+import torch
+import functools
+from typing import Optional, Tuple
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.transformers.transformer_cogview3plus import (
+    CogView3PlusTransformer2DModel,
+    CogVideoXAttnProcessor2_0,
+)
+from diffusers.models.transformers.transformer_cogview4 import (
+    CogView4Transformer2DModel,
+    CogView4AttnProcessor,
+)
+from diffusers.models.attention_processor import Attention
+from diffusers.models.attention_dispatch import dispatch_attention_fn
+
+try:
+    from diffusers.models._modeling_parallel import (
+        ContextParallelInput,
+        ContextParallelOutput,
+        ContextParallelModelPlan,
+    )
+except ImportError:
+    raise ImportError(
+        "Context parallelism requires the 'diffusers>=0.36.dev0'."
+        "Please install latest version of diffusers from source: \n"
+        "pip3 install git+https://github.com/huggingface/diffusers.git"
+    )
+from .cp_plan_registers import (
+    ContextParallelismPlanner,
+    ContextParallelismPlannerRegister,
+)
+from .cp_plan_cogvideox import __patch_CogVideoXAttnProcessor2_0__call__
+
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+@ContextParallelismPlannerRegister.register("CogView3Plus")
+class CogView3PlusContextParallelismPlanner(ContextParallelismPlanner):
+    def apply(
+        self,
+        transformer: Optional[torch.nn.Module | ModelMixin] = None,
+        **kwargs,
+    ) -> ContextParallelModelPlan:
+
+        # NOTE: Diffusers native CP plan still not supported
+        # for CogView3Plus now.
+        self._cp_planner_preferred_native_diffusers = False
+
+        if (
+            transformer is not None
+            and self._cp_planner_preferred_native_diffusers
+        ):
+            assert isinstance(
+                transformer, CogView3PlusTransformer2DModel
+            ), "Transformer must be an instance of CogView3PlusTransformer2DModel"
+            if hasattr(transformer, "_cp_plan"):
+                if transformer._cp_plan is not None:
+                    return transformer._cp_plan
+
+        # CogView3Plus and CogVideoX share the same attention processor
+        CogVideoXAttnProcessor2_0.__call__ = (
+            __patch_CogVideoXAttnProcessor2_0__call__
+        )
+        # Also need to patch the parallel config and attention backend
+        if not hasattr(CogVideoXAttnProcessor2_0, "_parallel_config"):
+            CogVideoXAttnProcessor2_0._parallel_config = None
+        if not hasattr(CogVideoXAttnProcessor2_0, "_attention_backend"):
+            CogVideoXAttnProcessor2_0._attention_backend = None
+
+        # Otherwise, use the custom CP plan defined here, this maybe
+        # a little different from the native diffusers implementation
+        # for some models.
+        _cp_plan = {
+            # Pattern of transformer_blocks.0, split_output=False:
+            # un-split input -> split -> to_qkv/...
+            # -> all2all
+            # -> attn (local head, full seqlen)
+            # -> all2all
+            # -> splited output
+            # Pattern of the rest transformer_blocks, split_output=False:
+            # splited input (previous splited output) -> to_qkv/...
+            # -> all2all
+            # -> attn (local head, full seqlen)
+            # -> all2all
+            # -> splited output
+            # The `encoder_hidden_states` will be changed after each block forward,
+            # so we need to split it at the first block, and keep it splited (namely,
+            # automatically split by the all2all op after attn) for the rest blocks.
+            # The `out` tensor of local attn will be splited into `hidden_states` and
+            # `encoder_hidden_states` after each block forward, thus both of them
+            # will be automatically splited by all2all comm op after local attn.
+            "transformer_blocks.0": {
+                "hidden_states": ContextParallelInput(
+                    split_dim=1, expected_dims=3, split_output=False
+                ),
+                "encoder_hidden_states": ContextParallelInput(
+                    split_dim=1, expected_dims=3, split_output=False
+                ),
+            },
+            # transformer forward while using CP, since it is not splited here.
+            # Then, the final proj_out will gather the splited output.
+            # splited input (previous splited output)
+            # -> all gather
+            # -> un-split output
+            "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+        }
+        return _cp_plan
+
+
+@ContextParallelismPlannerRegister.register("CogView4")
+class CogView4ContextParallelismPlanner(ContextParallelismPlanner):
+    def apply(
+        self,
+        transformer: Optional[torch.nn.Module | ModelMixin] = None,
+        **kwargs,
+    ) -> ContextParallelModelPlan:
+
+        # NOTE: Diffusers native CP plan still not supported
+        # for CogView4 now.
+        self._cp_planner_preferred_native_diffusers = False
+
+        if (
+            transformer is not None
+            and self._cp_planner_preferred_native_diffusers
+        ):
+            assert isinstance(
+                transformer, CogView4Transformer2DModel
+            ), "Transformer must be an instance of CogView4Transformer2DModel"
+            if hasattr(transformer, "_cp_plan"):
+                if transformer._cp_plan is not None:
+                    return transformer._cp_plan
+
+        CogView4AttnProcessor.__call__ = __patch_CogView4AttnProcessor__call__
+        # Also need to patch the parallel config and attention backend
+        if not hasattr(CogView4AttnProcessor, "_parallel_config"):
+            CogView4AttnProcessor._parallel_config = None
+        if not hasattr(CogView4AttnProcessor, "_attention_backend"):
+            CogView4AttnProcessor._attention_backend = None
+
+        # Otherwise, use the custom CP plan defined here, this maybe
+        # a little different from the native diffusers implementation
+        # for some models.
+        _cp_plan = {
+            # Pattern of transformer_blocks.0, split_output=False:
+            # un-split input -> split -> to_qkv/...
+            # -> all2all
+            # -> attn (local head, full seqlen)
+            # -> all2all
+            # -> splited output
+            # Pattern of the rest transformer_blocks, split_output=False:
+            # splited input (previous splited output) -> to_qkv/...
+            # -> all2all
+            # -> attn (local head, full seqlen)
+            # -> all2all
+            # -> splited output
+            # The `encoder_hidden_states` will be changed after each block forward,
+            # so we need to split it at the first block, and keep it splited (namely,
+            # automatically split by the all2all op after attn) for the rest blocks.
+            # The `out` tensor of local attn will be splited into `hidden_states` and
+            # `encoder_hidden_states` after each block forward, thus both of them
+            # will be automatically splited by all2all comm op after local attn.
+            "transformer_blocks.0": {
+                "hidden_states": ContextParallelInput(
+                    split_dim=1, expected_dims=3, split_output=False
+                ),
+                "encoder_hidden_states": ContextParallelInput(
+                    split_dim=1, expected_dims=3, split_output=False
+                ),
+            },
+            # Pattern of the image_rotary_emb, split at every block, because the it
+            # is not automatically splited by all2all comm op and keep un-splited
+            # while the block forward finished:
+            # un-split input -> split output
+            # -> after block forward
+            # -> un-split input
+            # un-split input -> split output
+            # ...
+            "transformer_blocks.*": {
+                "image_rotary_emb": [
+                    ContextParallelInput(
+                        split_dim=0, expected_dims=2, split_output=False
+                    ),
+                    ContextParallelInput(
+                        split_dim=0, expected_dims=2, split_output=False
+                    ),
+                ],
+            },
+            # transformer forward while using CP, since it is not splited here.
+            # Then, the final proj_out will gather the splited output.
+            # splited input (previous splited output)
+            # -> all gather
+            # -> un-split output
+            "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+        }
+        return _cp_plan
+
+
+@functools.wraps(CogView4AttnProcessor.__call__)
+def __patch_CogView4AttnProcessor__call__(
+    self: CogView4AttnProcessor,
+    attn: Attention,
+    hidden_states: torch.Tensor,
+    encoder_hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    dtype = encoder_hidden_states.dtype
+
+    batch_size, text_seq_length, embed_dim = encoder_hidden_states.shape
+    batch_size, image_seq_length, embed_dim = hidden_states.shape
+    hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+    # 1. QKV projections
+    query = attn.to_q(hidden_states)
+    key = attn.to_k(hidden_states)
+    value = attn.to_v(hidden_states)
+
+    # NOTE(DefTruth): no transpose
+    query = query.unflatten(2, (attn.heads, -1))
+    key = key.unflatten(2, (attn.heads, -1))
+    value = value.unflatten(2, (attn.heads, -1))
+
+    # 2. QK normalization
+    if attn.norm_q is not None:
+        query = attn.norm_q(query).to(dtype=dtype)
+    if attn.norm_k is not None:
+        key = attn.norm_k(key).to(dtype=dtype)
+
+    # 3. Rotational positional embeddings applied to latent stream
+    if image_rotary_emb is not None:
+        from diffusers.models.embeddings import apply_rotary_emb
+
+        query[:, text_seq_length:] = apply_rotary_emb(
+            query[:, text_seq_length:],
+            image_rotary_emb,
+            use_real_unbind_dim=-2,
+            sequence_dim=1,
+        )
+        key[:, text_seq_length:] = apply_rotary_emb(
+            key[:, text_seq_length:],
+            image_rotary_emb,
+            use_real_unbind_dim=-2,
+            sequence_dim=1,
+        )
+
+    # 4. Attention
+    if attention_mask is not None:
+        text_attn_mask = attention_mask
+        assert (
+            text_attn_mask.dim() == 2
+        ), "the shape of text_attn_mask should be (batch_size, text_seq_length)"
+        text_attn_mask = text_attn_mask.float().to(query.device)
+        mix_attn_mask = torch.ones(
+            (batch_size, text_seq_length + image_seq_length),
+            device=query.device,
+        )
+        mix_attn_mask[:, :text_seq_length] = text_attn_mask  # [B, seq_len]
+        # TODO(DefTruth): Permute mix_attn_mask if context parallel is used.
+        # For example, if work size = 2: [E, H] -> [E_0, H_0, E_1, H_1]
+        mix_attn_mask = mix_attn_mask.unsqueeze(2)  # [B, seq_len, 1]
+        attn_mask_matrix = mix_attn_mask @ mix_attn_mask.transpose(
+            1, 2
+        )  # [B, seq_len, seq_len]
+        attention_mask = (
+            (attn_mask_matrix > 0).unsqueeze(1).to(query.dtype)
+        )  # [B, 1, seq_len, seq_len]
+        if (
+            hasattr(self, "_parallel_config")
+            and self._parallel_config is not None
+        ):
+            raise NotImplementedError(
+                "Attention mask with context parallelism for CogView4 "
+                "is not implemented yet."
+            )
+
+    # NOTE(DefTruth): Apply dispatch_attention_fn instead of sdpa directly
+    hidden_states = dispatch_attention_fn(
+        query,
+        key,
+        value,
+        attn_mask=attention_mask,
+        dropout_p=0.0,
+        is_causal=False,
+        backend=getattr(self, "_attention_backend", None),
+        parallel_config=getattr(self, "_parallel_config", None),
+    )
+    hidden_states = hidden_states.flatten(2, 3)
+    hidden_states = hidden_states.type_as(query)
+
+    # 5. Output projection
+    hidden_states = attn.to_out[0](hidden_states)
+    hidden_states = attn.to_out[1](hidden_states)
+
+    encoder_hidden_states, hidden_states = hidden_states.split(
+        [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
+    )
+    return hidden_states, encoder_hidden_states
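
The patched CogView4AttnProcessor above builds its [B, 1, S, S] attention mask from a 1-D text validity mask via an outer product, so a query/key pair is kept only when both positions are valid. The snippet below is a small standalone check of that construction, not part of the package; the batch size and sequence lengths are invented for the example.

    import torch

    # Two text tokens (the second one padded) followed by two image tokens.
    text_attn_mask = torch.tensor([[1.0, 0.0]])  # [B=1, text_seq_length=2]
    text_seq_length, image_seq_length = 2, 2

    mix_attn_mask = torch.ones((1, text_seq_length + image_seq_length))
    mix_attn_mask[:, :text_seq_length] = text_attn_mask               # [[1, 0, 1, 1]]
    mix_attn_mask = mix_attn_mask.unsqueeze(2)                        # [B, S, 1]
    attn_mask_matrix = mix_attn_mask @ mix_attn_mask.transpose(1, 2)  # [B, S, S]
    attention_mask = (attn_mask_matrix > 0).unsqueeze(1).float()      # [B, 1, S, S]

    # Row and column 1 (the padded text token) are fully masked out.
    print(attention_mask[0, 0])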
cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cosisid.py
@@ -0,0 +1,123 @@
+import torch
+from typing import Optional
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.transformers.consisid_transformer_3d import (
+    ConsisIDTransformer3DModel,
+)
+from diffusers.models.transformers.cogvideox_transformer_3d import (
+    CogVideoXAttnProcessor2_0,
+)
+
+try:
+    from diffusers.models._modeling_parallel import (
+        ContextParallelInput,
+        ContextParallelOutput,
+        ContextParallelModelPlan,
+    )
+except ImportError:
+    raise ImportError(
+        "Context parallelism requires the 'diffusers>=0.36.dev0'."
+        "Please install latest version of diffusers from source: \n"
+        "pip3 install git+https://github.com/huggingface/diffusers.git"
+    )
+from .cp_plan_registers import (
+    ContextParallelismPlanner,
+    ContextParallelismPlannerRegister,
+)
+from .cp_plan_cogvideox import __patch_CogVideoXAttnProcessor2_0__call__
+
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+@ContextParallelismPlannerRegister.register("ConsisID")
+class CosisIDContextParallelismPlanner(ContextParallelismPlanner):
+    def apply(
+        self,
+        transformer: Optional[torch.nn.Module | ModelMixin] = None,
+        **kwargs,
+    ) -> ContextParallelModelPlan:
+
+        # NOTE: Diffusers native CP plan still not supported
+        # for ConsisID now.
+        self._cp_planner_preferred_native_diffusers = False
+
+        if (
+            transformer is not None
+            and self._cp_planner_preferred_native_diffusers
+        ):
+            assert isinstance(
+                transformer, ConsisIDTransformer3DModel
+            ), "Transformer must be an instance of ConsisIDTransformer3DModel"
+            if hasattr(transformer, "_cp_plan"):
+                if transformer._cp_plan is not None:
+                    return transformer._cp_plan
+
+        # ConsisID uses the same attention processor as CogVideoX.
+        CogVideoXAttnProcessor2_0.__call__ = (
+            __patch_CogVideoXAttnProcessor2_0__call__
+        )
+        # Also need to patch the parallel config and attention backend
+        if not hasattr(CogVideoXAttnProcessor2_0, "_parallel_config"):
+            CogVideoXAttnProcessor2_0._parallel_config = None
+        if not hasattr(CogVideoXAttnProcessor2_0, "_attention_backend"):
+            CogVideoXAttnProcessor2_0._attention_backend = None
+
+        # Otherwise, use the custom CP plan defined here, this maybe
+        # a little different from the native diffusers implementation
+        # for some models.
+        _cp_plan = {
+            # Pattern of transformer_blocks.0, split_output=False:
+            # un-split input -> split -> to_qkv/...
+            # -> all2all
+            # -> attn (local head, full seqlen)
+            # -> all2all
+            # -> splited output
+            # Pattern of the rest transformer_blocks, split_output=False:
+            # splited input (previous splited output) -> to_qkv/...
+            # -> all2all
+            # -> attn (local head, full seqlen)
+            # -> all2all
+            # -> splited output
+            # The `encoder_hidden_states` will be changed after each block forward,
+            # so we need to split it at the first block, and keep it splited (namely,
+            # automatically split by the all2all op after attn) for the rest blocks.
+            # The `out` tensor of local attn will be splited into `hidden_states` and
+            # `encoder_hidden_states` after each block forward, thus both of them
+            # will be automatically splited by all2all comm op after local attn.
+            "transformer_blocks.0": {
+                "hidden_states": ContextParallelInput(
+                    split_dim=1, expected_dims=3, split_output=False
+                ),
+                "encoder_hidden_states": ContextParallelInput(
+                    split_dim=1, expected_dims=3, split_output=False
+                ),
+            },
+            # Pattern of the image_rotary_emb, split at every block, because the it
+            # is not automatically splited by all2all comm op and keep un-splited
+            # while the block forward finished:
+            # un-split input -> split output
+            # -> after block forward
+            # -> un-split input
+            # un-split input -> split output
+            # ...
+            "transformer_blocks.*": {
+                "image_rotary_emb": [
+                    ContextParallelInput(
+                        split_dim=0, expected_dims=2, split_output=False
+                    ),
+                    ContextParallelInput(
+                        split_dim=0, expected_dims=2, split_output=False
+                    ),
+                ],
+            },
+            # NOTE: We should gather both hidden_states and encoder_hidden_states
+            # at the end of the last block. Because the subsequent op is:
+            # hidden_states = torch.cat([encoder_hidden_states, hidden_states])
+            f"transformer_blocks.{len(transformer.transformer_blocks) - 1}": [
+                ContextParallelOutput(gather_dim=1, expected_dims=3),
+                ContextParallelOutput(gather_dim=1, expected_dims=3),
+            ],
+        }
+        return _cp_plan
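
Unlike the CogVideoX and CogView plans, the ConsisID plan above ends with two ContextParallelOutput(gather_dim=1, ...) entries on the last transformer block, since the model concatenates encoder_hidden_states and hidden_states right after the block stack and therefore needs both at full sequence length. Conceptually that corresponds to an all-gather along the sequence dimension; the helper below is only a sketch of that operation, not the diffusers hook, and assumes torch.distributed has already been initialized.

    import torch
    import torch.distributed as dist

    def gather_sequence(local: torch.Tensor, gather_dim: int = 1) -> torch.Tensor:
        # Collect every rank's [B, S/world_size, D] shard and rebuild [B, S, D].
        world_size = dist.get_world_size()
        shards = [torch.empty_like(local) for _ in range(world_size)]
        dist.all_gather(shards, local.contiguous())
        return torch.cat(shards, dim=gather_dim)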