cache-dit 1.0.9__py3-none-any.whl → 1.0.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +1 -0
- cache_dit/cache_factory/block_adapters/__init__.py +37 -0
- cache_dit/cache_factory/block_adapters/block_adapters.py +51 -3
- cache_dit/cache_factory/block_adapters/block_registers.py +41 -14
- cache_dit/cache_factory/cache_adapters/cache_adapter.py +68 -30
- cache_dit/cache_factory/cache_contexts/cache_config.py +5 -3
- cache_dit/cache_factory/cache_contexts/cache_manager.py +125 -4
- cache_dit/cache_factory/cache_contexts/context_manager.py +9 -2
- cache_dit/cache_factory/cache_contexts/prune_manager.py +15 -2
- cache_dit/cache_factory/cache_interface.py +29 -3
- cache_dit/cache_factory/forward_pattern.py +14 -14
- cache_dit/parallelism/backends/native_diffusers/__init__.py +0 -3
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +95 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +74 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +254 -0
- cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +17 -61
- cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
- cache_dit/parallelism/backends/native_pytorch/__init__.py +3 -0
- cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +159 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +58 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +12 -0
- cache_dit/parallelism/parallel_backend.py +2 -0
- cache_dit/parallelism/parallel_config.py +8 -1
- cache_dit/parallelism/parallel_interface.py +9 -4
- cache_dit/quantize/backends/__init__.py +1 -0
- cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
- cache_dit/quantize/backends/torchao/__init__.py +1 -0
- cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +28 -9
- cache_dit/quantize/quantize_backend.py +0 -0
- cache_dit/quantize/quantize_config.py +0 -0
- cache_dit/quantize/quantize_interface.py +3 -16
- cache_dit/utils.py +22 -2
- {cache_dit-1.0.9.dist-info → cache_dit-1.0.10.dist-info}/METADATA +22 -13
- {cache_dit-1.0.9.dist-info → cache_dit-1.0.10.dist-info}/RECORD +45 -29
- /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
- /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
- {cache_dit-1.0.9.dist-info → cache_dit-1.0.10.dist-info}/WHEEL +0 -0
- {cache_dit-1.0.9.dist-info → cache_dit-1.0.10.dist-info}/entry_points.txt +0 -0
- {cache_dit-1.0.9.dist-info → cache_dit-1.0.10.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-1.0.9.dist-info → cache_dit-1.0.10.dist-info}/top_level.txt +0 -0
cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py (new file)
@@ -0,0 +1,254 @@
+# Docstring references: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/_modeling_parallel.py#L185
+# A dictionary where keys denote the input to be split across context parallel region, and the
+# value denotes the sharding configuration.
+# If the key is a string, it denotes the name of the parameter in the forward function.
+# If the key is an integer, split_output must be set to True, and it denotes the index of the output
+# to be split across context parallel region.
+# ContextParallelInputType = Dict[
+#     Union[str, int], Union[ContextParallelInput, List[ContextParallelInput], Tuple[ContextParallelInput, ...]]
+# ]
+
+# A dictionary where keys denote the output to be gathered across context parallel region, and the
+# value denotes the gathering configuration.
+# ContextParallelOutputType = Union[
+#     ContextParallelOutput, List[ContextParallelOutput], Tuple[ContextParallelOutput, ...]
+# ]
+
+# A dictionary where keys denote the module id, and the value denotes how the inputs/outputs of
+# the module should be split/gathered across context parallel region.
+# ContextParallelModelPlan = Dict[str, Union[ContextParallelInputType, ContextParallelOutputType]]
+
+# Example of a ContextParallelModelPlan (QwenImageTransformer2DModel):
+#
+# Each model should define a _cp_plan attribute that contains information on how to shard/gather
+# tensors at different stages of the forward:
+#
+# ```python
+# _cp_plan = {
+#     "": {
+#         "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+#         "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+#         "encoder_hidden_states_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
+#     },
+#     "pos_embed": {
+#         0: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
+#         1: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
+#     },
+#     "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+# }
+# ```
+#
+# The dictionary is a set of module names mapped to their respective CP plan. The inputs/outputs of layers will be
+# split/gathered according to this at the respective module level. Here, the following happens:
+# - "":
+#   we specify that we want to split the various inputs across the sequence dim in the pre-forward hook (i.e. before
+#   the actual forward logic of the QwenImageTransformer2DModel is run, we will split the inputs)
+# - "pos_embed":
+#   we specify that we want to split the outputs of the RoPE layer. Since there are two outputs (image & text freqs),
+#   we can individually specify how they should be split
+# - "proj_out":
+#   before returning to the user, we gather the entire sequence on each rank in the post-forward hook (after the linear
+#   layer forward has run).
+#
+# ContextParallelInput:
+#   specifies how to split the input tensor in the pre-forward or post-forward hook of the layer it is attached to
+#
+# ContextParallelOutput:
+#   specifies how to gather the input tensor in the post-forward hook of the layer it is attached to
+
+import torch
+from typing import Optional
+from diffusers.models.modeling_utils import ModelMixin
+
+try:
+    from diffusers.models._modeling_parallel import (
+        ContextParallelInput,
+        ContextParallelOutput,
+        ContextParallelModelPlan,
+    )
+except ImportError:
+    raise ImportError(
+        "Context parallelism requires 'diffusers>=0.36.dev0'. "
+        "Please install the latest version of diffusers from source: \n"
+        "pip3 install git+https://github.com/huggingface/diffusers.git"
+    )
+from .cp_plan_registers import (
+    ContextParallelismPlanner,
+    ContextParallelismPlannerRegister,
+)
+
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+__all__ = [
+    "ContextParallelismPlanner",
+    "ContextParallelismPlannerRegister",
+    "FluxContextParallelismPlanner",
+    "QwenImageContextParallelismPlanner",
+    "WanContextParallelismPlanner",
+    "LTXVideoContextParallelismPlanner",
+]
+
+
+# Register context parallelism planners for models.
+@ContextParallelismPlannerRegister.register("Flux")
+class FluxContextParallelismPlanner(ContextParallelismPlanner):
+    def apply(
+        self,
+        transformer: Optional[torch.nn.Module | ModelMixin] = None,
+        **kwargs,
+    ) -> ContextParallelModelPlan:
+        if transformer is not None:
+            from diffusers import FluxTransformer2DModel
+
+            assert isinstance(
+                transformer, FluxTransformer2DModel
+            ), "Transformer must be an instance of FluxTransformer2DModel"
+            if hasattr(transformer, "_cp_plan"):
+                return transformer._cp_plan
+
+        _cp_plan = {
+            "": {
+                "hidden_states": ContextParallelInput(
+                    split_dim=1, expected_dims=3, split_output=False
+                ),
+                "encoder_hidden_states": ContextParallelInput(
+                    split_dim=1, expected_dims=3, split_output=False
+                ),
+                "img_ids": ContextParallelInput(
+                    split_dim=0, expected_dims=2, split_output=False
+                ),
+                "txt_ids": ContextParallelInput(
+                    split_dim=0, expected_dims=2, split_output=False
+                ),
+            },
+            "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+        }
+        return _cp_plan
+
+
+@ContextParallelismPlannerRegister.register("QwenImage")
+class QwenImageContextParallelismPlanner(ContextParallelismPlanner):
+    def apply(
+        self,
+        transformer: Optional[torch.nn.Module | ModelMixin] = None,
+        **kwargs,
+    ) -> ContextParallelModelPlan:
+        if transformer is not None:
+            from diffusers import QwenImageTransformer2DModel
+
+            assert isinstance(
+                transformer, QwenImageTransformer2DModel
+            ), "Transformer must be an instance of QwenImageTransformer2DModel"
+            if hasattr(transformer, "_cp_plan"):
+                return transformer._cp_plan
+
+        _cp_plan = {
+            "": {
+                "hidden_states": ContextParallelInput(
+                    split_dim=1, expected_dims=3, split_output=False
+                ),
+                "encoder_hidden_states": ContextParallelInput(
+                    split_dim=1, expected_dims=3, split_output=False
+                ),
+                "encoder_hidden_states_mask": ContextParallelInput(
+                    split_dim=1, expected_dims=2, split_output=False
+                ),
+            },
+            "pos_embed": {
+                0: ContextParallelInput(
+                    split_dim=0, expected_dims=2, split_output=True
+                ),
+                1: ContextParallelInput(
+                    split_dim=0, expected_dims=2, split_output=True
+                ),
+            },
+            "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+        }
+        return _cp_plan
+
+
+# TODO: Add a WanVACETransformer3DModel context parallelism planner.
+# NOTE: We use the full name here to avoid a name conflict between
+# WanTransformer3DModel and WanVACETransformer3DModel.
+@ContextParallelismPlannerRegister.register("WanTransformer3D")
+class WanContextParallelismPlanner(ContextParallelismPlanner):
+    def apply(
+        self,
+        transformer: Optional[torch.nn.Module | ModelMixin] = None,
+        **kwargs,
+    ) -> ContextParallelModelPlan:
+        if transformer is not None:
+            from diffusers import WanTransformer3DModel
+
+            assert isinstance(
+                transformer, WanTransformer3DModel
+            ), "Transformer must be an instance of WanTransformer3DModel"
+            if hasattr(transformer, "_cp_plan"):
+                return transformer._cp_plan
+
+        _cp_plan = {
+            "rope": {
+                0: ContextParallelInput(
+                    split_dim=1, expected_dims=4, split_output=True
+                ),
+                1: ContextParallelInput(
+                    split_dim=1, expected_dims=4, split_output=True
+                ),
+            },
+            "blocks.0": {
+                "hidden_states": ContextParallelInput(
+                    split_dim=1, expected_dims=3, split_output=False
+                ),
+            },
+            "blocks.*": {
+                "encoder_hidden_states": ContextParallelInput(
+                    split_dim=1, expected_dims=3, split_output=False
+                ),
+            },
+            "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+        }
+        return _cp_plan
+
+
+@ContextParallelismPlannerRegister.register("LTXVideo")
+class LTXVideoContextParallelismPlanner(ContextParallelismPlanner):
+    def apply(
+        self,
+        transformer: Optional[torch.nn.Module | ModelMixin] = None,
+        **kwargs,
+    ) -> ContextParallelModelPlan:
+        if transformer is not None:
+            from diffusers import LTXVideoTransformer3DModel
+
+            assert isinstance(
+                transformer, LTXVideoTransformer3DModel
+            ), "Transformer must be an instance of LTXVideoTransformer3DModel"
+            if hasattr(transformer, "_cp_plan"):
+                return transformer._cp_plan
+
+        _cp_plan = {
+            "": {
+                "hidden_states": ContextParallelInput(
+                    split_dim=1, expected_dims=3, split_output=False
+                ),
+                "encoder_hidden_states": ContextParallelInput(
+                    split_dim=1, expected_dims=3, split_output=False
+                ),
+                "encoder_attention_mask": ContextParallelInput(
+                    split_dim=1, expected_dims=2, split_output=False
+                ),
+            },
+            "rope": {
+                0: ContextParallelInput(
+                    split_dim=1, expected_dims=3, split_output=True
+                ),
+                1: ContextParallelInput(
+                    split_dim=1, expected_dims=3, split_output=True
+                ),
+            },
+            "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+        }
+        return _cp_plan
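The registry above is the extension point this release introduces for context parallelism: a planner is registered under a name and returns a `ContextParallelModelPlan` from `apply()`. As a minimal sketch only (not part of the diff), the snippet below registers a planner for a hypothetical `MyVideoTransformer3DModel`; the model class, its module names, and the assumption that the registered name is matched against the transformer's class name are all illustrative, while the registry and type imports follow the file layout shown in this release.

```python
# Minimal sketch (illustrative, not from the diff): a custom context-parallelism
# planner registered the same way as the built-in Flux/QwenImage/Wan/LTXVideo ones.
from typing import Optional

import torch
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models._modeling_parallel import (  # requires diffusers >= 0.36.dev0
    ContextParallelInput,
    ContextParallelOutput,
    ContextParallelModelPlan,
)

from cache_dit.parallelism.backends.native_diffusers.context_parallelism.cp_plan_registers import (
    ContextParallelismPlanner,
    ContextParallelismPlannerRegister,
)


@ContextParallelismPlannerRegister.register("MyVideoTransformer3D")
class MyVideoContextParallelismPlanner(ContextParallelismPlanner):
    def apply(
        self,
        transformer: Optional[torch.nn.Module | ModelMixin] = None,
        **kwargs,
    ) -> ContextParallelModelPlan:
        # Prefer a _cp_plan already attached to the model, as the built-in planners do.
        if transformer is not None and hasattr(transformer, "_cp_plan"):
            return transformer._cp_plan

        # Split the sequence dimension on entry, gather it after the final
        # projection ("proj_out" is a hypothetical module name here).
        return {
            "": {
                "hidden_states": ContextParallelInput(
                    split_dim=1, expected_dims=3, split_output=False
                ),
            },
            "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
        }
```

Falling back to a hand-written plan only when the model does not already ship a `_cp_plan` mirrors the design choice of the built-in planners above: diffusers-provided plans win, cache-dit only fills the gap.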
cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py
@@ -6,22 +6,10 @@ from cache_dit.logger import init_logger
 logger = init_logger(__name__)
 
 
-try:
-    from diffusers import ContextParallelConfig
-
-    def native_diffusers_parallelism_available() -> bool:
-        return True
-
-except ImportError:
-    ContextParallelConfig = None
-
-    def native_diffusers_parallelism_available() -> bool:
-        return False
-
-
 from diffusers.models.modeling_utils import ModelMixin
 from cache_dit.parallelism.parallel_backend import ParallelismBackend
 from cache_dit.parallelism.parallel_config import ParallelismConfig
+from .context_parallelism import maybe_enable_context_parallelism
 
 
 def maybe_enable_parallelism(
@@ -40,54 +28,22 @@ def maybe_enable_parallelism(
         f" but got {type(parallelism_config)}"
     )
 
+    assert parallelism_config.backend == ParallelismBackend.NATIVE_DIFFUSER, (
+        f"parallelism backend must be {ParallelismBackend.NATIVE_DIFFUSER}, "
+        f"but got {parallelism_config.backend}"
+    )
+
     if (
-        parallelism_config.
-
+        parallelism_config.ulysses_size is not None
+        or parallelism_config.ring_size is not None
     ):
-
-
-            parallelism_config
-
-
-
-
-
-
-        if cp_config is not None:
-            attention_backend = parallelism_config.parallel_kwargs.get(
-                "attention_backend", None
-            )
-            if hasattr(transformer, "enable_parallelism"):
-                if hasattr(transformer, "set_attention_backend"):
-                    # _native_cudnn, flash, etc.
-                    if attention_backend is None:
-                        # Now only _native_cudnn is supported for parallelism
-                        # issue: https://github.com/huggingface/diffusers/pull/12443
-                        transformer.set_attention_backend("_native_cudnn")
-                        logger.warning(
-                            "attention_backend is None, set default attention backend "
-                            "to _native_cudnn for parallelism because of the issue: "
-                            "https://github.com/huggingface/diffusers/pull/12443"
-                        )
-                    else:
-                        transformer.set_attention_backend(attention_backend)
-                        logger.info(
-                            "Found attention_backend from config, set attention "
-                            f"backend to: {attention_backend}"
-                        )
-                cp_plan = parallelism_config.parallel_kwargs.get(
-                    "cp_plan", None
-                )
-                if cp_plan is not None:
-                    logger.info(
-                        f"Using custom context parallelism plan: {cp_plan}"
-                    )
-                transformer.enable_parallelism(
-                    config=cp_config, cp_plan=cp_plan
-                )
-            else:
-                raise ValueError(
-                    f"{transformer.__class__.__name__} does not support context parallelism."
-                )
-
+        transformer = maybe_enable_context_parallelism(
+            transformer,
+            parallelism_config,
+        )
+    else:
+        raise ValueError(
+            "NATIVE_DIFFUSER backend only support context parallelism now. "
+            "Please set ulysses_size or ring_size in parallelism_config."
+        )
     return transformer
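For orientation, here is a hedged sketch of how the refactored NATIVE_DIFFUSER path might be driven end to end. The `ParallelismConfig` constructor arguments are assumed from the fields this diff reads (`backend`, `ulysses_size`, `ring_size`, `parallel_kwargs`); cache-dit's higher-level interface in `parallel_interface.py` is not shown in this excerpt, so the sketch calls the backend entry point directly, and the checkpoint name is only an example.

```python
# Hedged sketch: enabling diffusers-native context parallelism on a pipeline's
# transformer. ParallelismConfig's constructor arguments are assumptions.
import torch
from diffusers import FluxPipeline

from cache_dit.parallelism.parallel_backend import ParallelismBackend
from cache_dit.parallelism.parallel_config import ParallelismConfig
from cache_dit.parallelism.backends.native_diffusers.parallel_difffusers import (
    maybe_enable_parallelism,
)

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)

config = ParallelismConfig(
    backend=ParallelismBackend.NATIVE_DIFFUSER,
    ulysses_size=2,  # the new code path requires ulysses_size or ring_size to be set
    # "attention_backend" is the key the pre-refactor code read from parallel_kwargs;
    # the new maybe_enable_context_parallelism helper presumably still honors it.
    parallel_kwargs={"attention_backend": "_native_cudnn"},
)

pipe.transformer = maybe_enable_parallelism(pipe.transformer, config)
```

Such a script would be launched with `torchrun --nproc_per_node=2 …` so that two ranks exist to shard the sequence across; with neither `ulysses_size` nor `ring_size` set, the new code raises a `ValueError` instead of silently running without parallelism.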
cache_dit/parallelism/backends/native_pytorch/parallel_torch.py (new file)
@@ -0,0 +1,62 @@
+from typing import Optional
+
+import torch
+
+from diffusers.models.modeling_utils import ModelMixin
+
+from cache_dit.parallelism.parallel_backend import ParallelismBackend
+from cache_dit.parallelism.parallel_config import ParallelismConfig
+
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+def maybe_enable_parallelism(
+    transformer: torch.nn.Module | ModelMixin,
+    parallelism_config: Optional[ParallelismConfig],
+) -> torch.nn.Module:
+    assert isinstance(transformer, torch.nn.Module), (
+        "transformer must be an instance of torch.nn.Module, "
+        f"but got {type(transformer)}"
+    )
+    assert isinstance(transformer, ModelMixin), (
+        "transformer must be an instance of diffusers' ModelMixin, "
+        f"but got {type(transformer)}"
+    )
+    if parallelism_config is None:
+        return transformer
+
+    assert parallelism_config.backend == ParallelismBackend.NATIVE_PYTORCH, (
+        "parallelism_config.backend must be ParallelismBackend.NATIVE_PYTORCH "
+        f"but got {parallelism_config.backend}"
+    )
+
+    assert isinstance(parallelism_config, ParallelismConfig), (
+        "parallelism_config must be an instance of ParallelismConfig"
+        f" but got {type(parallelism_config)}"
+    )
+    assert (
+        parallelism_config.ulysses_size is None
+        and parallelism_config.ring_size is None
+    ), (
+        "Ulysses/Ring parallelism is not supported in Native_PyTorch backend. "
+        "Please set it to None in parallelism_config."
+    )
+
+    if (
+        parallelism_config.tp_size is not None
+        and parallelism_config.tp_size > 1
+    ):
+        from .tensor_parallelism import maybe_enable_tensor_parallelism
+
+        transformer = maybe_enable_tensor_parallelism(
+            transformer=transformer,
+            parallelism_config=parallelism_config,
+        )
+    else:
+        raise ValueError(
+            "NATIVE_PYTORCH only supported tensor parallelism now. "
+            "Please set tp_size > 1 for tensor parallelism."
+        )
+    return transformer
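Similarly, a hedged sketch of the new NATIVE_PYTORCH tensor-parallel path added in this release; again the exact `ParallelismConfig` constructor is an assumption based on the fields read above (`backend`, `tp_size`, `ulysses_size`, `ring_size`), and the checkpoint is only an example.

```python
# Hedged sketch: sharding a transformer with the new native-PyTorch TP backend.
# Run under torchrun, e.g.: torchrun --nproc_per_node=2 this_script.py
import torch
from diffusers import FluxTransformer2DModel

from cache_dit.parallelism.parallel_backend import ParallelismBackend
from cache_dit.parallelism.parallel_config import ParallelismConfig
from cache_dit.parallelism.backends.native_pytorch.parallel_torch import (
    maybe_enable_parallelism,
)

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)

config = ParallelismConfig(
    backend=ParallelismBackend.NATIVE_PYTORCH,
    tp_size=2,          # shard attention/MLP weights across 2 ranks
    ulysses_size=None,  # must stay None; this backend asserts against Ulysses/Ring
    ring_size=None,
)

transformer = maybe_enable_parallelism(transformer, config)
```

`init_device_mesh` inside the Flux planner further down builds the TP mesh from the ranks torchrun spawns, so no manual process-group setup appears in the sketch.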
cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py (new file)
@@ -0,0 +1,48 @@
+try:
+    import einops
+except ImportError:
+    raise ImportError(
+        "Tensor parallelism requires the 'parallelism' extra dependencies. "
+        "Install with:\npip install cache-dit[parallelism]"
+    )
+
+import torch
+from typing import Optional
+from diffusers.models.modeling_utils import ModelMixin
+from cache_dit.parallelism.parallel_backend import ParallelismBackend
+from cache_dit.parallelism.parallel_config import ParallelismConfig
+from cache_dit.logger import init_logger
+from .tp_planners import *
+
+logger = init_logger(__name__)
+
+
+def maybe_enable_tensor_parallelism(
+    transformer: torch.nn.Module | ModelMixin,
+    parallelism_config: Optional[ParallelismConfig],
+) -> torch.nn.Module:
+    assert isinstance(transformer, torch.nn.Module), (
+        "transformer must be an instance of torch.nn.Module, "
+        f"but got {type(transformer)}"
+    )
+    assert isinstance(transformer, ModelMixin), (
+        "transformer must be an instance of diffusers' ModelMixin, "
+        f"but got {type(transformer)}"
+    )
+    if parallelism_config is None:
+        return transformer
+
+    assert parallelism_config.backend == ParallelismBackend.NATIVE_PYTORCH, (
+        "parallelism_config.backend must be ParallelismBackend.NATIVE_PYTORCH "
+        f"but got {parallelism_config.backend}"
+    )
+
+    extra_parallel_kwargs = {}
+    if parallelism_config.parallel_kwargs is not None:
+        extra_parallel_kwargs = parallelism_config.parallel_kwargs
+
+    return TensorParallelismPlannerRegister.get_planner(transformer)().apply(
+        transformer=transformer,
+        parallelism_config=parallelism_config,
+        **extra_parallel_kwargs,
+    )
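Dispatch here is registry-driven: `get_planner(transformer)` selects a registered `TensorParallelismPlanner` (presumably by matching the registered name against the transformer's class name, as the `WanTransformer3D` naming note in cp_planners.py above suggests; `tp_plan_registers.py` itself is not shown in this excerpt), and any entries in `ParallelismConfig.parallel_kwargs` are forwarded into the planner's `apply(**kwargs)`. A hypothetical sketch of a third-party planner that consumes such a kwarg; every model and module name in it is invented for illustration.

```python
# Hypothetical sketch: a third-party TP planner picking up an extra kwarg that
# maybe_enable_tensor_parallelism forwards from ParallelismConfig.parallel_kwargs.
import torch
from torch.distributed import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

from cache_dit.parallelism.parallel_config import ParallelismConfig
from cache_dit.parallelism.backends.native_pytorch.tensor_parallelism.tp_plan_registers import (
    TensorParallelismPlanner,
    TensorParallelismPlannerRegister,
)


@TensorParallelismPlannerRegister.register("MyDiT")  # name matching is an assumption
class MyDiTTensorParallelismPlanner(TensorParallelismPlanner):
    def apply(
        self,
        transformer: torch.nn.Module,
        parallelism_config: ParallelismConfig,
        shard_mlp: bool = True,  # arrives via parallel_kwargs={"shard_mlp": ...}
        **kwargs,
    ) -> torch.nn.Module:
        device_type = torch.accelerator.current_accelerator().type
        tp_mesh = init_device_mesh(device_type, [parallelism_config.tp_size])

        plan = {
            "attn.to_q": ColwiseParallel(),
            "attn.to_k": ColwiseParallel(),
            "attn.to_v": ColwiseParallel(),
            "attn.to_out.0": RowwiseParallel(),
        }
        if shard_mlp:
            plan["ff.net.0.proj"] = ColwiseParallel()
            plan["ff.net.2"] = RowwiseParallel()

        # Module paths above are illustrative; shard every block of the model.
        for _, block in transformer.transformer_blocks.named_children():
            block.attn.heads //= tp_mesh.size()
            parallelize_module(block, tp_mesh, plan)
        return transformer
```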
cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py (new file)
@@ -0,0 +1,159 @@
+import torch
+from diffusers.models.transformers.transformer_flux import (
+    FluxSingleTransformerBlock,
+)
+from einops import rearrange
+from torch import nn
+from torch.distributed import DeviceMesh, init_device_mesh
+from torch.distributed._tensor import Replicate
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    RowwiseParallel,
+    parallelize_module,
+)
+from cache_dit.parallelism.parallel_config import ParallelismConfig
+from .tp_plan_registers import (
+    TensorParallelismPlanner,
+    TensorParallelismPlannerRegister,
+)
+
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+@TensorParallelismPlannerRegister.register("Flux")
+class FluxTensorParallelismPlanner(TensorParallelismPlanner):
+    def apply(
+        self,
+        transformer: torch.nn.Module,
+        parallelism_config: ParallelismConfig,
+        **kwargs,
+    ) -> torch.nn.Module:
+        assert (
+            parallelism_config.tp_size is not None
+            and parallelism_config.tp_size > 1
+        ), (
+            "parallel_config.tp_size must be set and greater than 1 for "
+            "tensor parallelism"
+        )
+
+        device_type = torch.accelerator.current_accelerator().type
+        tp_mesh: DeviceMesh = init_device_mesh(
+            device_type=device_type,
+            mesh_shape=[parallelism_config.tp_size],
+        )
+
+        transformer = self.parallelize_transformer(
+            transformer=transformer,
+            tp_mesh=tp_mesh,
+        )
+        # TODO: Parallelize t5 text encoder via `apply_extra`
+        # abstract method and `extra_parallel_kwargs` ?
+
+        return transformer
+
+    def parallelize_t5(
+        self,
+        text_encoder: nn.Module,
+        tp_mesh: DeviceMesh,
+    ):
+        for i, block in enumerate(text_encoder.encoder.block):
+            block.layer[0].SelfAttention.n_heads //= tp_mesh.size()
+            block.layer[0].SelfAttention.inner_dim //= tp_mesh.size()
+            layer_plan = {
+                "layer.0.SelfAttention.q": ColwiseParallel(),
+                "layer.0.SelfAttention.k": ColwiseParallel(),
+                "layer.0.SelfAttention.v": ColwiseParallel(),
+                "layer.0.SelfAttention.o": RowwiseParallel(),
+                "layer.1.DenseReluDense.wi_0": ColwiseParallel(),
+                "layer.1.DenseReluDense.wi_1": ColwiseParallel(),
+                "layer.1.DenseReluDense.wo": RowwiseParallel(),
+            }
+            if i == 0:
+                layer_plan["layer.0.SelfAttention.relative_attention_bias"] = (
+                    ColwiseParallel()
+                )
+            parallelize_module(
+                module=block,
+                device_mesh=tp_mesh,
+                parallelize_plan=layer_plan,
+            )
+
+        return text_encoder
+
+    def parallelize_transformer(
+        self,
+        transformer: nn.Module,
+        tp_mesh: DeviceMesh,
+    ):
+        for _, block in transformer.transformer_blocks.named_children():
+            block.attn.heads //= tp_mesh.size()
+            layer_plan = {
+                "attn.to_q": ColwiseParallel(),
+                "attn.to_k": ColwiseParallel(),
+                "attn.to_v": ColwiseParallel(),
+                "attn.to_out.0": RowwiseParallel(),
+                "norm1.linear": ColwiseParallel(output_layouts=Replicate()),
+                "ff.net.0.proj": ColwiseParallel(),
+                "ff.net.2": RowwiseParallel(),
+                "attn.add_q_proj": ColwiseParallel(),
+                "attn.add_k_proj": ColwiseParallel(),
+                "attn.add_v_proj": ColwiseParallel(),
+                "attn.to_add_out": RowwiseParallel(),
+                "norm1_context.linear": ColwiseParallel(
+                    output_layouts=Replicate()
+                ),
+                "ff_context.net.0.proj": ColwiseParallel(),
+                "ff_context.net.2": RowwiseParallel(),
+            }
+            parallelize_module(
+                module=block,
+                device_mesh=tp_mesh,
+                parallelize_plan=layer_plan,
+            )
+
+        # NOTE: special handling for FluxSingleTransformerBlock, we have to
+        # rearrange the proj_out weight because it contains both out and down
+        # projection weights in a single matrix.
+        def rearrange_proj_out_weight(
+            single_block: FluxSingleTransformerBlock, tp_group_size
+        ):
+            # rowwise
+            hidden_dim = 3072
+            requires_grad = single_block.proj_out.weight.requires_grad
+            linear2_weight_data = (
+                single_block.proj_out.weight.data.T.detach().clone()
+            )
+            out_weight = linear2_weight_data[:hidden_dim, ...]
+            out_weight = rearrange(
+                out_weight, "(G D) C -> G D C", G=tp_group_size
+            )
+            down_weight = linear2_weight_data.data[hidden_dim:, ...]
+            down_weight = rearrange(
+                down_weight, "(G D) C -> G D C", G=tp_group_size
+            )
+            new_linear2_weight = torch.cat([out_weight, down_weight], dim=1)
+            new_linear2_weight = rearrange(
+                new_linear2_weight, "G D C -> (G D) C"
+            )
+            single_block.proj_out.weight.data.copy_(new_linear2_weight.T)
+            single_block.proj_out.weight.requires_grad_(requires_grad)
+
+        for _, block in transformer.single_transformer_blocks.named_children():
+            rearrange_proj_out_weight(block, tp_mesh.size())
+            block.attn.heads //= tp_mesh.size()
+            layer_plan = {
+                "attn.to_q": ColwiseParallel(),
+                "attn.to_k": ColwiseParallel(),
+                "attn.to_v": ColwiseParallel(),
+                "proj_mlp": ColwiseParallel(),
+                "proj_out": RowwiseParallel(),
+                "norm.linear": ColwiseParallel(output_layouts=Replicate()),
+            }
+            parallelize_module(
+                module=block,
+                device_mesh=tp_mesh,
+                parallelize_plan=layer_plan,
+            )
+        return transformer
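The subtle piece of this file is `rearrange_proj_out_weight`: `FluxSingleTransformerBlock.proj_out` is a single linear whose input features are the concatenation of the attention output and the MLP hidden activation, so before `RowwiseParallel` splits that input dimension contiguously, the rows of the transposed weight have to be regrouped so each rank's slice lines up with the columns its column-parallel `attn.to_*`/`proj_mlp` shards produce. A standalone illustration of that regrouping with toy sizes (the real attention width is 3072):

```python
# Toy illustration of the "(G D) C -> G D C" regrouping used above:
# hidden_dim=4 attention-output features, mlp_dim=8 MLP features, tp=2 ranks.
import torch
from einops import rearrange

hidden_dim, mlp_dim, tp = 4, 8, 2
# Rows of proj_out.weight.T: the first hidden_dim rows multiply the attention
# output, the remaining mlp_dim rows multiply the MLP hidden activation.
fused = torch.arange((hidden_dim + mlp_dim) * 3).float().reshape(hidden_dim + mlp_dim, 3)

out_w, down_w = fused[:hidden_dim], fused[hidden_dim:]
# Group each half by TP rank, then interleave the halves rank by rank, so a
# plain contiguous row split hands every rank its share of *both* projections.
out_w = rearrange(out_w, "(G D) C -> G D C", G=tp)
down_w = rearrange(down_w, "(G D) C -> G D C", G=tp)
reordered = rearrange(torch.cat([out_w, down_w], dim=1), "G D C -> (G D) C")

# Rank 0 now owns rows 0..5 (its 2 attention rows + its 4 MLP rows);
# rank 1 owns rows 6..11 analogously.
print(reordered.shape)  # torch.Size([12, 3])
```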