cache-dit 1.0.3__py3-none-any.whl → 1.0.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cache_dit/__init__.py +37 -19
- cache_dit/_version.py +2 -2
- cache_dit/caching/__init__.py +36 -0
- cache_dit/{cache_factory → caching}/block_adapters/__init__.py +126 -11
- cache_dit/{cache_factory → caching}/block_adapters/block_adapters.py +78 -7
- cache_dit/caching/block_adapters/block_registers.py +118 -0
- cache_dit/caching/cache_adapters/__init__.py +1 -0
- cache_dit/{cache_factory → caching}/cache_adapters/cache_adapter.py +214 -114
- cache_dit/caching/cache_blocks/__init__.py +226 -0
- cache_dit/caching/cache_blocks/pattern_0_1_2.py +26 -0
- cache_dit/caching/cache_blocks/pattern_3_4_5.py +543 -0
- cache_dit/caching/cache_blocks/pattern_base.py +748 -0
- cache_dit/caching/cache_blocks/pattern_utils.py +86 -0
- cache_dit/caching/cache_contexts/__init__.py +28 -0
- cache_dit/caching/cache_contexts/cache_config.py +120 -0
- cache_dit/{cache_factory → caching}/cache_contexts/cache_context.py +18 -94
- cache_dit/{cache_factory → caching}/cache_contexts/cache_manager.py +133 -12
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/__init__.py +25 -3
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/foca.py +1 -1
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/taylorseer.py +81 -9
- cache_dit/caching/cache_contexts/context_manager.py +36 -0
- cache_dit/caching/cache_contexts/prune_config.py +63 -0
- cache_dit/caching/cache_contexts/prune_context.py +155 -0
- cache_dit/caching/cache_contexts/prune_manager.py +167 -0
- cache_dit/{cache_factory → caching}/cache_interface.py +150 -37
- cache_dit/{cache_factory → caching}/cache_types.py +19 -2
- cache_dit/{cache_factory → caching}/forward_pattern.py +14 -14
- cache_dit/{cache_factory → caching}/params_modifier.py +10 -10
- cache_dit/caching/patch_functors/__init__.py +15 -0
- cache_dit/{cache_factory → caching}/patch_functors/functor_chroma.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_dit.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_flux.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_hidream.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_hunyuan_dit.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_qwen_image_controlnet.py +1 -1
- cache_dit/{cache_factory → caching}/utils.py +19 -8
- cache_dit/metrics/__init__.py +11 -0
- cache_dit/parallelism/__init__.py +3 -0
- cache_dit/parallelism/backends/native_diffusers/__init__.py +6 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +164 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/__init__.py +4 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_attention_dispatch.py +304 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_chroma.py +95 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogvideox.py +202 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogview.py +299 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cosisid.py +123 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_dit.py +94 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_flux.py +88 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_hunyuan.py +729 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_ltxvideo.py +264 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_nunchaku.py +407 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_pixart.py +285 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_qwen_image.py +104 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +84 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_wan.py +101 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +117 -0
- cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +49 -0
- cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
- cache_dit/parallelism/backends/native_pytorch/__init__.py +6 -0
- cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +171 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_kandinsky5.py +79 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +65 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +14 -0
- cache_dit/parallelism/parallel_backend.py +26 -0
- cache_dit/parallelism/parallel_config.py +88 -0
- cache_dit/parallelism/parallel_interface.py +77 -0
- cache_dit/quantize/__init__.py +7 -0
- cache_dit/quantize/backends/__init__.py +1 -0
- cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
- cache_dit/quantize/backends/torchao/__init__.py +1 -0
- cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +40 -30
- cache_dit/quantize/quantize_backend.py +0 -0
- cache_dit/quantize/quantize_config.py +0 -0
- cache_dit/quantize/quantize_interface.py +3 -16
- cache_dit/summary.py +593 -0
- cache_dit/utils.py +46 -290
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/METADATA +123 -116
- cache_dit-1.0.14.dist-info/RECORD +102 -0
- cache_dit-1.0.14.dist-info/licenses/LICENSE +203 -0
- cache_dit/cache_factory/__init__.py +0 -28
- cache_dit/cache_factory/block_adapters/block_registers.py +0 -90
- cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
- cache_dit/cache_factory/cache_blocks/__init__.py +0 -76
- cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +0 -16
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +0 -306
- cache_dit/cache_factory/cache_blocks/pattern_base.py +0 -458
- cache_dit/cache_factory/cache_blocks/pattern_utils.py +0 -41
- cache_dit/cache_factory/cache_contexts/__init__.py +0 -15
- cache_dit/cache_factory/patch_functors/__init__.py +0 -15
- cache_dit-1.0.3.dist-info/RECORD +0 -58
- cache_dit-1.0.3.dist-info/licenses/LICENSE +0 -53
- /cache_dit/{cache_factory → caching}/.gitignore +0 -0
- /cache_dit/{cache_factory → caching}/cache_blocks/offload_utils.py +0 -0
- /cache_dit/{cache_factory → caching}/cache_contexts/calibrators/base.py +0 -0
- /cache_dit/{cache_factory → caching}/patch_functors/functor_base.py +0 -0
- /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
- /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/WHEEL +0 -0
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/entry_points.txt +0 -0
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
from typing import Optional
|
|
4
|
+
from cache_dit.logger import init_logger
|
|
5
|
+
|
|
6
|
+
logger = init_logger(__name__)
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
from diffusers.models.modeling_utils import ModelMixin
|
|
10
|
+
from cache_dit.parallelism.parallel_backend import ParallelismBackend
|
|
11
|
+
from cache_dit.parallelism.parallel_config import ParallelismConfig
|
|
12
|
+
from .context_parallelism import maybe_enable_context_parallelism
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def maybe_enable_parallelism(
    transformer: torch.nn.Module,
    parallelism_config: Optional[ParallelismConfig],
) -> torch.nn.Module:
    """Enable parallelism on ``transformer`` via the NATIVE_DIFFUSER backend.

    Args:
        transformer: The model to parallelize; must be a diffusers
            ``ModelMixin`` instance.
        parallelism_config: The parallelism settings. When ``None`` the
            transformer is returned untouched.

    Returns:
        The transformer returned by ``maybe_enable_context_parallelism``,
        or the input transformer when no config is given.

    Raises:
        AssertionError: If ``transformer`` is not a ``ModelMixin``, if
            ``parallelism_config`` is not a ``ParallelismConfig``, or if its
            backend is not ``ParallelismBackend.NATIVE_DIFFUSER``.
        ValueError: If neither ``ulysses_size`` nor ``ring_size`` is set,
            because context parallelism is the only mode handled here.
    """
    assert isinstance(transformer, ModelMixin), (
        "transformer must be an instance of diffusers' ModelMixin, "
        f"but got {type(transformer)}"
    )

    # Nothing requested: hand the model back unchanged.
    if parallelism_config is None:
        return transformer

    assert isinstance(parallelism_config, ParallelismConfig), (
        "parallelism_config must be an instance of ParallelismConfig"
        f" but got {type(parallelism_config)}"
    )

    assert parallelism_config.backend == ParallelismBackend.NATIVE_DIFFUSER, (
        f"parallelism backend must be {ParallelismBackend.NATIVE_DIFFUSER}, "
        f"but got {parallelism_config.backend}"
    )

    # Context parallelism is selected by setting either size knob.
    no_context_parallelism = (
        parallelism_config.ulysses_size is None
        and parallelism_config.ring_size is None
    )
    if no_context_parallelism:
        raise ValueError(
            "NATIVE_DIFFUSER backend only support context parallelism now. "
            "Please set ulysses_size or ring_size in parallelism_config."
        )

    return maybe_enable_context_parallelism(
        transformer,
        parallelism_config,
    )
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
from diffusers.models.modeling_utils import ModelMixin
|
|
6
|
+
|
|
7
|
+
from cache_dit.parallelism.parallel_backend import ParallelismBackend
|
|
8
|
+
from cache_dit.parallelism.parallel_config import ParallelismConfig
|
|
9
|
+
|
|
10
|
+
from cache_dit.logger import init_logger
|
|
11
|
+
|
|
12
|
+
logger = init_logger(__name__)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def maybe_enable_parallelism(
    transformer: torch.nn.Module | ModelMixin,
    parallelism_config: Optional[ParallelismConfig],
) -> torch.nn.Module:
    """Enable parallelism on ``transformer`` via the NATIVE_PYTORCH backend.

    Args:
        transformer: The model to parallelize; must be both a
            ``torch.nn.Module`` and a diffusers ``ModelMixin``.
        parallelism_config: The parallelism settings. When ``None`` the
            transformer is returned untouched.

    Returns:
        The transformer returned by ``maybe_enable_tensor_parallelism``,
        or the input transformer when no config is given.

    Raises:
        AssertionError: If ``transformer`` or ``parallelism_config`` has the
            wrong type, if the backend is not
            ``ParallelismBackend.NATIVE_PYTORCH``, or if ``ulysses_size`` /
            ``ring_size`` is set.
        ValueError: If ``tp_size`` is ``None`` or not greater than 1, because
            tensor parallelism is the only mode handled here.
    """
    assert isinstance(transformer, torch.nn.Module), (
        "transformer must be an instance of torch.nn.Module, "
        f"but got {type(transformer)}"
    )
    assert isinstance(transformer, ModelMixin), (
        "transformer must be an instance of diffusers' ModelMixin, "
        f"but got {type(transformer)}"
    )
    if parallelism_config is None:
        return transformer

    # Check the config type BEFORE reading any attribute from it, so that a
    # wrongly-typed config reports the intended message instead of failing
    # with an AttributeError on `.backend`.
    assert isinstance(parallelism_config, ParallelismConfig), (
        "parallelism_config must be an instance of ParallelismConfig"
        f" but got {type(parallelism_config)}"
    )

    assert parallelism_config.backend == ParallelismBackend.NATIVE_PYTORCH, (
        "parallelism_config.backend must be ParallelismBackend.NATIVE_PYTORCH "
        f"but got {parallelism_config.backend}"
    )

    assert (
        parallelism_config.ulysses_size is None
        and parallelism_config.ring_size is None
    ), (
        "Ulysses/Ring parallelism is not supported in Native_PyTorch backend. "
        "Please set it to None in parallelism_config."
    )

    if (
        parallelism_config.tp_size is not None
        and parallelism_config.tp_size > 1
    ):
        # Imported lazily: the tensor_parallelism package checks for its
        # optional dependency at import time.
        from .tensor_parallelism import maybe_enable_tensor_parallelism

        transformer = maybe_enable_tensor_parallelism(
            transformer=transformer,
            parallelism_config=parallelism_config,
        )
    else:
        raise ValueError(
            "NATIVE_PYTORCH only supported tensor parallelism now. "
            "Please set tp_size > 1 for tensor parallelism."
        )
    return transformer
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
# `einops` is imported here only to verify that the optional 'parallelism'
# extra is installed; fail early with an actionable message otherwise.
try:
    import einops  # noqa: F401
except ImportError as exc:
    raise ImportError(
        "Parallelism functionality requires the 'parallelism' extra dependencies. "
        "Install with:\npip install cache-dit[parallelism]"
    ) from exc
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
from typing import Optional
|
|
11
|
+
from diffusers.models.modeling_utils import ModelMixin
|
|
12
|
+
from cache_dit.parallelism.parallel_backend import ParallelismBackend
|
|
13
|
+
from cache_dit.parallelism.parallel_config import ParallelismConfig
|
|
14
|
+
from cache_dit.logger import init_logger
|
|
15
|
+
from .tp_planners import *
|
|
16
|
+
|
|
17
|
+
logger = init_logger(__name__)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def maybe_enable_tensor_parallelism(
    transformer: torch.nn.Module | ModelMixin,
    parallelism_config: Optional[ParallelismConfig],
) -> torch.nn.Module:
    """Apply the registered tensor-parallelism planner to ``transformer``.

    The planner class is looked up through
    ``TensorParallelismPlannerRegister.get_planner`` and instantiated with no
    arguments; its ``apply`` method does the actual work.

    Args:
        transformer: The model to parallelize; must be both a
            ``torch.nn.Module`` and a diffusers ``ModelMixin``.
        parallelism_config: The parallelism settings. When ``None`` the
            transformer is returned untouched. Its ``parallel_kwargs``, when
            not ``None``, are forwarded to the planner as keyword arguments.

    Returns:
        Whatever the planner's ``apply`` returns, or the input transformer
        when no config is given.

    Raises:
        AssertionError: If ``transformer`` has the wrong type or the backend
            is not ``ParallelismBackend.NATIVE_PYTORCH``.
    """
    assert isinstance(transformer, torch.nn.Module), (
        "transformer must be an instance of torch.nn.Module, "
        f"but got {type(transformer)}"
    )
    assert isinstance(transformer, ModelMixin), (
        "transformer must be an instance of diffusers' ModelMixin, "
        f"but got {type(transformer)}"
    )

    if parallelism_config is None:
        return transformer

    assert parallelism_config.backend == ParallelismBackend.NATIVE_PYTORCH, (
        "parallelism_config.backend must be ParallelismBackend.NATIVE_PYTORCH "
        f"but got {parallelism_config.backend}"
    )

    # Optional user-supplied keyword arguments for the planner.
    extra_parallel_kwargs = (
        parallelism_config.parallel_kwargs
        if parallelism_config.parallel_kwargs is not None
        else {}
    )

    planner_cls = TensorParallelismPlannerRegister.get_planner(transformer)
    planner = planner_cls()
    return planner.apply(
        transformer=transformer,
        parallelism_config=parallelism_config,
        **extra_parallel_kwargs,
    )
|
|
@@ -0,0 +1,171 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from diffusers.models.transformers.transformer_flux import (
|
|
3
|
+
FluxSingleTransformerBlock,
|
|
4
|
+
)
|
|
5
|
+
from einops import rearrange
|
|
6
|
+
from torch import nn
|
|
7
|
+
from torch.distributed import DeviceMesh, init_device_mesh
|
|
8
|
+
from torch.distributed._tensor import Replicate
|
|
9
|
+
from torch.distributed.tensor.parallel import (
|
|
10
|
+
ColwiseParallel,
|
|
11
|
+
RowwiseParallel,
|
|
12
|
+
parallelize_module,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
from cache_dit.logger import init_logger
|
|
16
|
+
from cache_dit.parallelism.parallel_config import ParallelismConfig
|
|
17
|
+
|
|
18
|
+
from .tp_plan_registers import (
|
|
19
|
+
TensorParallelismPlanner,
|
|
20
|
+
TensorParallelismPlannerRegister,
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
logger = init_logger(__name__)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@TensorParallelismPlannerRegister.register("Chroma")
@TensorParallelismPlannerRegister.register("HunyuanImage")
@TensorParallelismPlannerRegister.register("HunyuanVideo")
@TensorParallelismPlannerRegister.register("Flux")
class FluxTensorParallelismPlanner(TensorParallelismPlanner):
    """Tensor-parallelism planner for Flux-style transformers.

    Registered under the name prefixes "Flux", "HunyuanVideo", "HunyuanImage"
    and "Chroma".
    """

    def apply(
        self,
        transformer: torch.nn.Module,
        parallelism_config: ParallelismConfig,
        **kwargs,
    ) -> torch.nn.Module:
        """Build a 1-D device mesh of ``tp_size`` and shard ``transformer``.

        Args:
            transformer: The model whose blocks are parallelized in place.
            parallelism_config: Must carry a ``tp_size`` greater than 1.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            The same transformer after ``parallelize_transformer``.
        """
        assert (
            parallelism_config.tp_size is not None
            and parallelism_config.tp_size > 1
        ), (
            "parallel_config.tp_size must be set and greater than 1 for "
            "tensor parallelism"
        )

        accelerator_type = torch.accelerator.current_accelerator().type
        tp_mesh: DeviceMesh = init_device_mesh(
            device_type=accelerator_type,
            mesh_shape=[parallelism_config.tp_size],
        )

        # TODO: Parallelize t5 text encoder via `apply_extra`
        # abstract method and `extra_parallel_kwargs` ?
        return self.parallelize_transformer(
            transformer=transformer,
            tp_mesh=tp_mesh,
        )

    def parallelize_t5(
        self,
        text_encoder: nn.Module,
        tp_mesh: DeviceMesh,
    ):
        """Shard every block in ``text_encoder.encoder.block`` over ``tp_mesh``.

        Not called by ``apply``. Each block's self-attention ``n_heads`` and
        ``inner_dim`` are floor-divided by the mesh size before the
        column-/row-wise plan is applied.
        """
        for index, block in enumerate(text_encoder.encoder.block):
            self_attention = block.layer[0].SelfAttention
            self_attention.n_heads //= tp_mesh.size()
            self_attention.inner_dim //= tp_mesh.size()

            plan = {
                "layer.0.SelfAttention.q": ColwiseParallel(),
                "layer.0.SelfAttention.k": ColwiseParallel(),
                "layer.0.SelfAttention.v": ColwiseParallel(),
                "layer.0.SelfAttention.o": RowwiseParallel(),
                "layer.1.DenseReluDense.wi_0": ColwiseParallel(),
                "layer.1.DenseReluDense.wi_1": ColwiseParallel(),
                "layer.1.DenseReluDense.wo": RowwiseParallel(),
            }
            # Only the first block gets a plan entry for the relative
            # attention bias.
            if index == 0:
                plan["layer.0.SelfAttention.relative_attention_bias"] = (
                    ColwiseParallel()
                )

            parallelize_module(
                module=block,
                device_mesh=tp_mesh,
                parallelize_plan=plan,
            )

        return text_encoder

    @staticmethod
    def _rearrange_proj_out_weight(
        single_block: FluxSingleTransformerBlock, tp_group_size
    ):
        """Regroup ``proj_out`` rows so each TP group holds its own slices.

        NOTE: special handling for FluxSingleTransformerBlock, we have to
        rearrange the proj_out weight because it contains both out and down
        projection weights in a single matrix.
        """
        # rowwise
        split_at = single_block.attn.to_q.weight.shape[0]
        keep_requires_grad = single_block.proj_out.weight.requires_grad

        # Work on a detached, transposed copy of the fused weight.
        fused = single_block.proj_out.weight.data.T.detach().clone()

        out_part = rearrange(
            fused[:split_at, ...], "(G D) C -> G D C", G=tp_group_size
        )
        down_part = rearrange(
            fused.data[split_at:, ...], "(G D) C -> G D C", G=tp_group_size
        )

        # Place each group's out/down slices next to each other, then
        # flatten the group axis back into rows.
        regrouped = torch.cat([out_part, down_part], dim=1)
        regrouped = rearrange(regrouped, "G D C -> (G D) C")

        single_block.proj_out.weight.data.copy_(regrouped.T)
        single_block.proj_out.weight.requires_grad_(keep_requires_grad)

    def parallelize_transformer(
        self,
        transformer: nn.Module,
        tp_mesh: DeviceMesh,
    ):
        """Shard the double-stream and single-stream blocks over ``tp_mesh``.

        Attention ``heads`` of every block is floor-divided by the mesh size
        before the column-/row-wise plan is applied.
        """
        # Blocks under `transformer_blocks` (image + context streams).
        for block in transformer.transformer_blocks.children():
            block.attn.heads //= tp_mesh.size()
            plan = {
                "attn.to_q": ColwiseParallel(),
                "attn.to_k": ColwiseParallel(),
                "attn.to_v": ColwiseParallel(),
                "attn.to_out.0": RowwiseParallel(),
                "ff.net.0.proj": ColwiseParallel(),
                "ff.net.2": RowwiseParallel(),
                "attn.add_q_proj": ColwiseParallel(),
                "attn.add_k_proj": ColwiseParallel(),
                "attn.add_v_proj": ColwiseParallel(),
                "attn.to_add_out": RowwiseParallel(),
                "ff_context.net.0.proj": ColwiseParallel(),
                "ff_context.net.2": RowwiseParallel(),
            }

            # Norm layers only join the plan when they own a `linear`.
            if getattr(block.norm1, "linear", None) is not None:
                plan["norm1.linear"] = ColwiseParallel(
                    output_layouts=Replicate()
                )
            if getattr(block.norm1_context, "linear", None) is not None:
                plan["norm1_context.linear"] = ColwiseParallel(
                    output_layouts=Replicate()
                )

            parallelize_module(
                module=block,
                device_mesh=tp_mesh,
                parallelize_plan=plan,
            )

        # Blocks under `single_transformer_blocks`: the fused `proj_out`
        # weight is regrouped before it is sharded row-wise.
        for block in transformer.single_transformer_blocks.children():
            self._rearrange_proj_out_weight(block, tp_mesh.size())
            block.attn.heads //= tp_mesh.size()
            plan = {
                "attn.to_q": ColwiseParallel(),
                "attn.to_k": ColwiseParallel(),
                "attn.to_v": ColwiseParallel(),
                "proj_mlp": ColwiseParallel(),
                "proj_out": RowwiseParallel(),
            }
            if getattr(block.norm, "linear", None) is not None:
                plan["norm.linear"] = ColwiseParallel(
                    output_layouts=Replicate()
                )
            parallelize_module(
                module=block,
                device_mesh=tp_mesh,
                parallelize_plan=plan,
            )

        return transformer
|
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from torch import nn
|
|
3
|
+
from torch.distributed import DeviceMesh, init_device_mesh
|
|
4
|
+
from torch.distributed._tensor import Replicate
|
|
5
|
+
from torch.distributed.tensor.parallel import (
|
|
6
|
+
ColwiseParallel,
|
|
7
|
+
RowwiseParallel,
|
|
8
|
+
parallelize_module,
|
|
9
|
+
)
|
|
10
|
+
|
|
11
|
+
from cache_dit.logger import init_logger
|
|
12
|
+
from cache_dit.parallelism.parallel_config import ParallelismConfig
|
|
13
|
+
|
|
14
|
+
from .tp_plan_registers import (
|
|
15
|
+
TensorParallelismPlanner,
|
|
16
|
+
TensorParallelismPlannerRegister,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
logger = init_logger(__name__)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@TensorParallelismPlannerRegister.register("Kandinsky5")
class Kandinsky5TensorParallelismPlanner(TensorParallelismPlanner):
    """Tensor-parallelism planner registered under the "Kandinsky5" prefix."""

    def apply(
        self,
        transformer: torch.nn.Module,
        parallelism_config: ParallelismConfig,
        **kwargs,
    ) -> torch.nn.Module:
        """Build a 1-D device mesh of ``tp_size`` and shard ``transformer``.

        Args:
            transformer: The model whose blocks are parallelized in place.
            parallelism_config: Must carry a ``tp_size`` greater than 1.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            The same transformer after ``parallelize_transformer``.
        """
        assert (
            parallelism_config.tp_size is not None
            and parallelism_config.tp_size > 1
        ), (
            "parallel_config.tp_size must be set and greater than 1 for "
            "tensor parallelism"
        )

        accelerator_type = torch.accelerator.current_accelerator().type
        tp_mesh: DeviceMesh = init_device_mesh(
            device_type=accelerator_type,
            mesh_shape=[parallelism_config.tp_size],
        )

        return self.parallelize_transformer(
            transformer=transformer,
            tp_mesh=tp_mesh,
        )

    def parallelize_transformer(
        self,
        transformer: nn.Module,
        tp_mesh: DeviceMesh,
    ):
        """Shard every block in ``visual_transformer_blocks`` over ``tp_mesh``.

        ``num_heads`` of both the self- and cross-attention modules is
        floor-divided by the mesh size before the plan is applied.
        """
        for block in transformer.visual_transformer_blocks.children():
            block.self_attention.num_heads //= tp_mesh.size()
            block.cross_attention.num_heads //= tp_mesh.size()

            plan = {
                "self_attention.to_query": ColwiseParallel(),
                "self_attention.to_key": ColwiseParallel(),
                "self_attention.to_value": ColwiseParallel(),
                "self_attention.out_layer": RowwiseParallel(),
                "cross_attention.to_query": ColwiseParallel(),
                "cross_attention.to_key": ColwiseParallel(),
                "cross_attention.to_value": ColwiseParallel(),
                "cross_attention.out_layer": RowwiseParallel(),
                # The modulation output is requested with a replicated layout.
                "visual_modulation.out_layer": ColwiseParallel(
                    output_layouts=Replicate()
                ),
                "feed_forward.in_layer": ColwiseParallel(),
                "feed_forward.out_layer": RowwiseParallel(),
            }
            parallelize_module(
                module=block,
                device_mesh=tp_mesh,
                parallelize_plan=plan,
            )

        return transformer
|
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from torch import nn
|
|
3
|
+
from torch.distributed import DeviceMesh, init_device_mesh
|
|
4
|
+
from torch.distributed._tensor import Replicate
|
|
5
|
+
from torch.distributed.tensor.parallel import (
|
|
6
|
+
ColwiseParallel,
|
|
7
|
+
RowwiseParallel,
|
|
8
|
+
parallelize_module,
|
|
9
|
+
)
|
|
10
|
+
from cache_dit.parallelism.parallel_config import ParallelismConfig
|
|
11
|
+
from .tp_plan_registers import (
|
|
12
|
+
TensorParallelismPlanner,
|
|
13
|
+
TensorParallelismPlannerRegister,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
from cache_dit.logger import init_logger
|
|
17
|
+
|
|
18
|
+
logger = init_logger(__name__)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@TensorParallelismPlannerRegister.register("QwenImage")
class QwenImageTensorParallelismPlanner(TensorParallelismPlanner):
    """Tensor-parallelism planner registered under the "QwenImage" prefix."""

    def apply(
        self,
        transformer: torch.nn.Module,
        parallelism_config: ParallelismConfig,
        **kwargs,
    ) -> torch.nn.Module:
        """Build a 1-D device mesh of ``tp_size`` and shard ``transformer``.

        Args:
            transformer: The model whose blocks are parallelized in place.
            parallelism_config: Must carry a ``tp_size`` greater than 1.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            The same transformer after ``parallelize_transformer``.
        """
        assert (
            parallelism_config.tp_size is not None
            and parallelism_config.tp_size > 1
        ), (
            "parallel_config.tp_size must be set and greater than 1 for "
            "tensor parallelism"
        )

        accelerator_type = torch.accelerator.current_accelerator().type
        tp_mesh: DeviceMesh = init_device_mesh(
            device_type=accelerator_type,
            mesh_shape=[parallelism_config.tp_size],
        )

        return self.parallelize_transformer(
            transformer=transformer,
            tp_mesh=tp_mesh,
        )

    def parallelize_transformer(
        self,
        transformer: nn.Module,
        tp_mesh: DeviceMesh,
    ):
        """Shard every block in ``transformer_blocks`` over ``tp_mesh``.

        Attention ``heads`` of each block is floor-divided by the mesh size
        before the column-/row-wise plan is applied.
        """
        for block in transformer.transformer_blocks.children():
            block.attn.heads //= tp_mesh.size()

            plan = {
                # Image stream.
                "attn.to_q": ColwiseParallel(),
                "attn.to_k": ColwiseParallel(),
                "attn.to_v": ColwiseParallel(),
                "attn.to_out.0": RowwiseParallel(),
                "img_mod.1": ColwiseParallel(output_layouts=Replicate()),
                "img_mlp.net.0.proj": ColwiseParallel(),
                "img_mlp.net.2": RowwiseParallel(),
                # Text stream.
                "attn.add_q_proj": ColwiseParallel(),
                "attn.add_k_proj": ColwiseParallel(),
                "attn.add_v_proj": ColwiseParallel(),
                "attn.to_add_out": RowwiseParallel(),
                "txt_mod.1": ColwiseParallel(output_layouts=Replicate()),
                "txt_mlp.net.0.proj": ColwiseParallel(),
                "txt_mlp.net.2": RowwiseParallel(),
            }
            parallelize_module(
                module=block,
                device_mesh=tp_mesh,
                parallelize_plan=plan,
            )

        return transformer
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import logging
|
|
3
|
+
from abc import abstractmethod
|
|
4
|
+
from typing import Dict
|
|
5
|
+
from cache_dit.parallelism.parallel_config import ParallelismConfig
|
|
6
|
+
from cache_dit.logger import init_logger
|
|
7
|
+
|
|
8
|
+
logger = init_logger(__name__)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class TensorParallelismPlanner:
    """Base interface for model-specific tensor-parallelism planners.

    Subclasses override ``apply``. This class has no ``ABC`` base, so the
    ``@abstractmethod`` marker is not enforced at instantiation time; calling
    the base ``apply`` raises ``NotImplementedError`` instead.
    """

    # TODO: add `apply_extra` abstract method for extra
    # parallelism kwargs handling

    @abstractmethod
    def apply(
        self,
        transformer: torch.nn.Module,
        parallelism_config: ParallelismConfig,
        **kwargs,
    ) -> torch.nn.Module:
        """Parallelize ``transformer`` according to ``parallelism_config``.

        Args:
            transformer: The model to parallelize.
            parallelism_config: The parallelism settings.
            **kwargs: Extra planner-specific keyword arguments.

        Returns:
            The parallelized module (in subclass implementations).

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        message = "apply method must be implemented by subclasses"
        raise NotImplementedError(message)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class TensorParallelismPlannerRegister:
    """Class-level registry mapping model-name prefixes to planner classes."""

    # Values are planner CLASSES (not instances): `register` stores the
    # decorated class and callers instantiate the result of `get_planner`.
    _tp_planner_registry: Dict[str, type[TensorParallelismPlanner]] = {}

    @classmethod
    def register(cls, name: str):
        """Return a class decorator that registers a planner under ``name``.

        Args:
            name: Prefix of the transformer class names this planner handles.

        Returns:
            A decorator that stores the planner class and returns it unchanged.

        Raises:
            AssertionError: When decorating, if ``name`` is already registered.
        """

        def decorator(planner_cls: type[TensorParallelismPlanner]):
            assert (
                name not in cls._tp_planner_registry
            ), f"TensorParallelismPlanner with name {name} is already registered."
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug(f"Registering TensorParallelismPlanner: {name}")
            cls._tp_planner_registry[name] = planner_cls
            return planner_cls

        return decorator

    @classmethod
    def get_planner(
        cls, transformer: str | torch.nn.Module
    ) -> type[TensorParallelismPlanner]:
        """Look up the planner class for a transformer or a class name.

        Args:
            transformer: Either a module, whose class name is used, or the
                name itself as a string.

        Returns:
            The first registered planner class, in registration order, whose
            registered name is a prefix of the transformer name.

        Raises:
            ValueError: If no registered name is a prefix of the given name.
        """
        if isinstance(transformer, torch.nn.Module):
            name = transformer.__class__.__name__
        else:
            name = transformer

        # Prefix match: e.g. a name registered as "Flux" serves every class
        # whose name starts with "Flux".
        for planner_name, planner_cls in cls._tp_planner_registry.items():
            if name.startswith(planner_name):
                return planner_cls

        raise ValueError(f"No planner registered under name: {name}")

    @classmethod
    def supported_planners(
        cls,
    ) -> tuple[int, list[str]]:
        """Return the number of registered names and the names themselves."""
        planner_names = list(cls._tp_planner_registry)
        return len(planner_names), planner_names
|