cache-dit 1.0.8__py3-none-any.whl → 1.0.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cache-dit might be problematic. Click here for more details.
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +1 -0
- cache_dit/cache_factory/block_adapters/__init__.py +37 -0
- cache_dit/cache_factory/block_adapters/block_adapters.py +75 -4
- cache_dit/cache_factory/block_adapters/block_registers.py +44 -17
- cache_dit/cache_factory/cache_adapters/cache_adapter.py +72 -30
- cache_dit/cache_factory/cache_contexts/cache_config.py +5 -3
- cache_dit/cache_factory/cache_contexts/cache_manager.py +125 -4
- cache_dit/cache_factory/cache_contexts/context_manager.py +9 -2
- cache_dit/cache_factory/cache_contexts/prune_manager.py +15 -2
- cache_dit/cache_factory/cache_interface.py +102 -28
- cache_dit/cache_factory/forward_pattern.py +14 -14
- cache_dit/parallelism/backends/native_diffusers/__init__.py +0 -3
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +95 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +74 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +254 -0
- cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +17 -49
- cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
- cache_dit/parallelism/backends/native_pytorch/__init__.py +3 -0
- cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +159 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +58 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +12 -0
- cache_dit/parallelism/parallel_backend.py +2 -0
- cache_dit/parallelism/parallel_config.py +10 -3
- cache_dit/parallelism/parallel_interface.py +14 -5
- cache_dit/quantize/backends/__init__.py +1 -0
- cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
- cache_dit/quantize/backends/torchao/__init__.py +1 -0
- cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +28 -9
- cache_dit/quantize/quantize_backend.py +0 -0
- cache_dit/quantize/quantize_config.py +0 -0
- cache_dit/quantize/quantize_interface.py +3 -16
- cache_dit/utils.py +56 -20
- {cache_dit-1.0.8.dist-info → cache_dit-1.0.10.dist-info}/METADATA +24 -13
- {cache_dit-1.0.8.dist-info → cache_dit-1.0.10.dist-info}/RECORD +45 -29
- /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
- /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
- {cache_dit-1.0.8.dist-info → cache_dit-1.0.10.dist-info}/WHEEL +0 -0
- {cache_dit-1.0.8.dist-info → cache_dit-1.0.10.dist-info}/entry_points.txt +0 -0
- {cache_dit-1.0.8.dist-info → cache_dit-1.0.10.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-1.0.8.dist-info → cache_dit-1.0.10.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
# einops is an optional dependency (the Flux TP plan imports `rearrange` from
# it); fail early with an actionable message when the extra is missing.
try:
    import einops  # noqa: F401  (presence check only)
except ImportError as e:
    raise ImportError(
        "Tensor parallelism functionality requires the 'parallelism' extra "
        "dependencies. Install with:\npip install cache-dit[parallelism]"
    ) from e
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
from typing import Optional
|
|
11
|
+
from diffusers.models.modeling_utils import ModelMixin
|
|
12
|
+
from cache_dit.parallelism.parallel_backend import ParallelismBackend
|
|
13
|
+
from cache_dit.parallelism.parallel_config import ParallelismConfig
|
|
14
|
+
from cache_dit.logger import init_logger
|
|
15
|
+
from .tp_planners import *
|
|
16
|
+
|
|
17
|
+
logger = init_logger(__name__)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def maybe_enable_tensor_parallelism(
    transformer: torch.nn.Module | ModelMixin,
    parallelism_config: Optional[ParallelismConfig],
) -> torch.nn.Module:
    """Apply the registered tensor-parallel plan to ``transformer``.

    Args:
        transformer: A diffusers ``ModelMixin`` (hence also a
            ``torch.nn.Module``) to parallelize.
        parallelism_config: Parallelism settings. When ``None`` the
            transformer is returned untouched. Otherwise its backend must be
            ``ParallelismBackend.NATIVE_PYTORCH``.

    Returns:
        Whatever the selected planner's ``apply`` returns.

    Raises:
        AssertionError: If ``transformer`` has the wrong type or the config
            targets a different backend.
        ValueError: Propagated from the registry when no planner matches the
            transformer's class name.
    """
    received_type = type(transformer)
    assert isinstance(transformer, torch.nn.Module), (
        "transformer must be an instance of torch.nn.Module, "
        f"but got {received_type}"
    )
    assert isinstance(transformer, ModelMixin), (
        "transformer must be an instance of diffusers' ModelMixin, "
        f"but got {received_type}"
    )

    # Nothing to do without a config (the type checks above still run).
    if parallelism_config is None:
        return transformer

    backend = parallelism_config.backend
    assert backend == ParallelismBackend.NATIVE_PYTORCH, (
        "parallelism_config.backend must be ParallelismBackend.NATIVE_PYTORCH "
        f"but got {backend}"
    )

    # Forward any backend-specific extras to the planner untouched.
    planner_kwargs = parallelism_config.parallel_kwargs
    if planner_kwargs is None:
        planner_kwargs = {}

    planner_cls = TensorParallelismPlannerRegister.get_planner(transformer)
    planner = planner_cls()
    return planner.apply(
        transformer=transformer,
        parallelism_config=parallelism_config,
        **planner_kwargs,
    )
|
|
@@ -0,0 +1,159 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from diffusers.models.transformers.transformer_flux import (
|
|
3
|
+
FluxSingleTransformerBlock,
|
|
4
|
+
)
|
|
5
|
+
from einops import rearrange
|
|
6
|
+
from torch import nn
|
|
7
|
+
from torch.distributed import DeviceMesh, init_device_mesh
|
|
8
|
+
from torch.distributed._tensor import Replicate
|
|
9
|
+
from torch.distributed.tensor.parallel import (
|
|
10
|
+
ColwiseParallel,
|
|
11
|
+
RowwiseParallel,
|
|
12
|
+
parallelize_module,
|
|
13
|
+
)
|
|
14
|
+
from cache_dit.parallelism.parallel_config import ParallelismConfig
|
|
15
|
+
from .tp_plan_registers import (
|
|
16
|
+
TensorParallelismPlanner,
|
|
17
|
+
TensorParallelismPlannerRegister,
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
from cache_dit.logger import init_logger
|
|
21
|
+
|
|
22
|
+
logger = init_logger(__name__)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@TensorParallelismPlannerRegister.register("Flux")
class FluxTensorParallelismPlanner(TensorParallelismPlanner):
    """Tensor-parallel plan for diffusers' Flux transformers.

    Shards attention projections, feed-forward layers and AdaLN modulation
    linears of both the double-stream ``transformer_blocks`` and the
    single-stream ``single_transformer_blocks`` over a 1-D device mesh of
    size ``parallelism_config.tp_size``.
    """

    def apply(
        self,
        transformer: torch.nn.Module,
        parallelism_config: ParallelismConfig,
        **kwargs,
    ) -> torch.nn.Module:
        """Build the TP device mesh and parallelize ``transformer``.

        Args:
            transformer: The Flux transformer to parallelize (modified in
                place by ``parallelize_module``).
            parallelism_config: Must have ``tp_size`` set and > 1.
            **kwargs: Extra backend kwargs; currently unused.

        Returns:
            The parallelized ``transformer``.
        """
        assert (
            parallelism_config.tp_size is not None
            and parallelism_config.tp_size > 1
        ), (
            "parallel_config.tp_size must be set and greater than 1 for "
            "tensor parallelism"
        )

        # 1-D mesh on the current accelerator type (e.g. "cuda").
        device_type = torch.accelerator.current_accelerator().type
        tp_mesh: DeviceMesh = init_device_mesh(
            device_type=device_type,
            mesh_shape=[parallelism_config.tp_size],
        )

        transformer = self.parallelize_transformer(
            transformer=transformer,
            tp_mesh=tp_mesh,
        )
        # TODO: Parallelize t5 text encoder via `apply_extra`
        # abstract method and `extra_parallel_kwargs` ?

        return transformer

    def parallelize_t5(
        self,
        text_encoder: nn.Module,
        tp_mesh: DeviceMesh,
    ):
        """Shard the self-attention and FFN linears of a T5 encoder in place.

        ``n_heads`` and ``inner_dim`` are divided by the mesh size so the
        attention reshapes match the column-sharded q/k/v outputs.
        NOTE(review): not called by ``apply`` yet (see the TODO there).

        Returns:
            ``text_encoder`` (modified in place).
        """
        for i, block in enumerate(text_encoder.encoder.block):
            block.layer[0].SelfAttention.n_heads //= tp_mesh.size()
            block.layer[0].SelfAttention.inner_dim //= tp_mesh.size()
            layer_plan = {
                "layer.0.SelfAttention.q": ColwiseParallel(),
                "layer.0.SelfAttention.k": ColwiseParallel(),
                "layer.0.SelfAttention.v": ColwiseParallel(),
                "layer.0.SelfAttention.o": RowwiseParallel(),
                "layer.1.DenseReluDense.wi_0": ColwiseParallel(),
                "layer.1.DenseReluDense.wi_1": ColwiseParallel(),
                "layer.1.DenseReluDense.wo": RowwiseParallel(),
            }
            # The relative attention bias is sharded only in block 0
            # (presumably the only block that owns it in T5).
            if i == 0:
                layer_plan["layer.0.SelfAttention.relative_attention_bias"] = (
                    ColwiseParallel()
                )
            parallelize_module(
                module=block,
                device_mesh=tp_mesh,
                parallelize_plan=layer_plan,
            )

        return text_encoder

    @staticmethod
    def _rearrange_proj_out_weight(
        single_block: FluxSingleTransformerBlock, tp_group_size: int
    ) -> None:
        """Reorder ``proj_out``'s input features so a row-wise split is valid.

        ``proj_out`` packs the attention "out" projection (first
        ``hidden_dim`` input features) and the MLP "down" projection
        (remaining features) into one matrix. Under the TP plan each rank
        produces only its shard of both halves. Row-wise parallelism splits
        the input features into ``tp_group_size`` contiguous chunks, so the
        weight's input features are regrouped as
        ``[out_0, down_0, out_1, down_1, ...]`` to line up with the local
        activations.
        """
        proj_weight = single_block.proj_out.weight
        # proj_out maps (hidden_dim + mlp_dim) -> hidden_dim, so its output
        # width *is* the attention hidden size. Derive it instead of
        # hard-coding FLUX.1's 3072 so other Flux widths are split correctly.
        hidden_dim = proj_weight.shape[0]
        requires_grad = proj_weight.requires_grad

        # Work on (in_features, out_features) so input features are rows.
        weight_t = proj_weight.data.T.detach().clone()
        out_weight = rearrange(
            weight_t[:hidden_dim, ...], "(G D) C -> G D C", G=tp_group_size
        )
        down_weight = rearrange(
            weight_t[hidden_dim:, ...], "(G D) C -> G D C", G=tp_group_size
        )
        # Interleave per-rank groups: [out_g, down_g] for g in 0..G-1.
        new_weight_t = rearrange(
            torch.cat([out_weight, down_weight], dim=1), "G D C -> (G D) C"
        )
        proj_weight.data.copy_(new_weight_t.T)
        proj_weight.requires_grad_(requires_grad)

    def parallelize_transformer(
        self,
        transformer: nn.Module,
        tp_mesh: DeviceMesh,
    ):
        """Apply the Flux TP layer plans to every transformer block.

        Attention head counts are divided by the mesh size because q/k/v are
        column-sharded. AdaLN modulation linears are column-sharded with a
        replicated output so every rank sees the full modulation vector.

        Returns:
            ``transformer`` (modified in place).
        """
        tp_size = tp_mesh.size()

        # Double-stream blocks: image and context streams.
        for _, block in transformer.transformer_blocks.named_children():
            block.attn.heads //= tp_size
            layer_plan = {
                "attn.to_q": ColwiseParallel(),
                "attn.to_k": ColwiseParallel(),
                "attn.to_v": ColwiseParallel(),
                "attn.to_out.0": RowwiseParallel(),
                "norm1.linear": ColwiseParallel(output_layouts=Replicate()),
                "ff.net.0.proj": ColwiseParallel(),
                "ff.net.2": RowwiseParallel(),
                "attn.add_q_proj": ColwiseParallel(),
                "attn.add_k_proj": ColwiseParallel(),
                "attn.add_v_proj": ColwiseParallel(),
                "attn.to_add_out": RowwiseParallel(),
                "norm1_context.linear": ColwiseParallel(
                    output_layouts=Replicate()
                ),
                "ff_context.net.0.proj": ColwiseParallel(),
                "ff_context.net.2": RowwiseParallel(),
            }
            parallelize_module(
                module=block,
                device_mesh=tp_mesh,
                parallelize_plan=layer_plan,
            )

        # Single-stream blocks: proj_out needs its packed weight regrouped
        # *before* it is row-sharded.
        for _, block in transformer.single_transformer_blocks.named_children():
            self._rearrange_proj_out_weight(block, tp_size)
            block.attn.heads //= tp_size
            layer_plan = {
                "attn.to_q": ColwiseParallel(),
                "attn.to_k": ColwiseParallel(),
                "attn.to_v": ColwiseParallel(),
                "proj_mlp": ColwiseParallel(),
                "proj_out": RowwiseParallel(),
                "norm.linear": ColwiseParallel(output_layouts=Replicate()),
            }
            parallelize_module(
                module=block,
                device_mesh=tp_mesh,
                parallelize_plan=layer_plan,
            )
        return transformer
|
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from torch import nn
|
|
3
|
+
from torch.distributed import DeviceMesh, init_device_mesh
|
|
4
|
+
from torch.distributed._tensor import Replicate
|
|
5
|
+
from torch.distributed.tensor.parallel import (
|
|
6
|
+
ColwiseParallel,
|
|
7
|
+
RowwiseParallel,
|
|
8
|
+
parallelize_module,
|
|
9
|
+
)
|
|
10
|
+
from cache_dit.parallelism.parallel_config import ParallelismConfig
|
|
11
|
+
from .tp_plan_registers import (
|
|
12
|
+
TensorParallelismPlanner,
|
|
13
|
+
TensorParallelismPlannerRegister,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
from cache_dit.logger import init_logger
|
|
17
|
+
|
|
18
|
+
logger = init_logger(__name__)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@TensorParallelismPlannerRegister.register("QwenImage")
class QwenImageTensorParallelismPlanner(TensorParallelismPlanner):
    """Tensor-parallel plan for diffusers' QwenImage dual-stream transformer.

    Every block in ``transformer.transformer_blocks`` has an image ("img")
    and a text ("txt") stream sharing one joint attention module. Both
    streams get the same sharding recipe.
    """

    def apply(
        self,
        transformer: torch.nn.Module,
        parallelism_config: ParallelismConfig,
        **kwargs,
    ) -> torch.nn.Module:
        """Create a 1-D TP mesh and parallelize ``transformer`` with it.

        ``parallelism_config.tp_size`` must be set and > 1. Extra ``kwargs``
        are accepted for interface compatibility and ignored.
        """
        tp_size = parallelism_config.tp_size
        assert tp_size is not None and tp_size > 1, (
            "parallel_config.tp_size must be set and greater than 1 for "
            "tensor parallelism"
        )

        accelerator_type = torch.accelerator.current_accelerator().type
        mesh: DeviceMesh = init_device_mesh(
            device_type=accelerator_type,
            mesh_shape=[tp_size],
        )
        return self.parallelize_transformer(transformer=transformer, tp_mesh=mesh)

    @staticmethod
    def _stream_plan(qkv_names, out_name, stream):
        """Layer plan for one stream: attention in/out, AdaLN mod and MLP."""
        plan = {f"attn.{proj}": ColwiseParallel() for proj in qkv_names}
        plan[f"attn.{out_name}"] = RowwiseParallel()
        # Modulation output is replicated so each rank sees all shift/scale terms.
        plan[f"{stream}_mod.1"] = ColwiseParallel(output_layouts=Replicate())
        plan[f"{stream}_mlp.net.0.proj"] = ColwiseParallel()
        plan[f"{stream}_mlp.net.2"] = RowwiseParallel()
        return plan

    def parallelize_transformer(
        self,
        transformer: nn.Module,
        tp_mesh: DeviceMesh,
    ):
        """Shard every QwenImage block in place and return ``transformer``."""
        shard_count = tp_mesh.size()
        for block in transformer.transformer_blocks.children():
            # q/k/v are column-sharded, so each rank owns heads / shard_count heads.
            block.attn.heads //= shard_count
            layer_plan = {
                **self._stream_plan(("to_q", "to_k", "to_v"), "to_out.0", "img"),
                **self._stream_plan(
                    ("add_q_proj", "add_k_proj", "add_v_proj"), "to_add_out", "txt"
                ),
            }
            parallelize_module(
                module=block,
                device_mesh=tp_mesh,
                parallelize_plan=layer_plan,
            )
        return transformer
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import logging
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from typing import Dict
|
|
5
|
+
from cache_dit.parallelism.parallel_config import ParallelismConfig
|
|
6
|
+
from cache_dit.logger import init_logger
|
|
7
|
+
|
|
8
|
+
logger = init_logger(__name__)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class TensorParallelismPlanner(ABC):
    """Base class for model-specific tensor-parallelism plans.

    Concrete planners are registered with
    ``TensorParallelismPlannerRegister.register(name)`` and must implement
    :meth:`apply`. The class derives from ``ABC`` so that ``@abstractmethod``
    is enforced. Without it the decorator is inert, and a subclass missing
    ``apply`` would only fail when ``apply`` is called. Now it fails at
    instantiation.
    """

    # TODO: add `apply_extra` abstract method for extra
    # parallelism kwargs handling

    @abstractmethod
    def apply(
        self,
        transformer: torch.nn.Module,
        parallelism_config: ParallelismConfig,
        **kwargs,
    ) -> torch.nn.Module:
        """Parallelize ``transformer`` according to ``parallelism_config``.

        Returns:
            The parallelized transformer.
        """
        raise NotImplementedError(
            "apply method must be implemented by subclasses"
        )
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class TensorParallelismPlannerRegister:
    """Registry mapping transformer class-name prefixes to planner classes."""

    # Registered name (a class-name prefix such as "Flux") -> planner *class*;
    # callers instantiate the class themselves.
    _tp_planner_registry: Dict[str, type[TensorParallelismPlanner]] = {}

    @classmethod
    def register(cls, name: str):
        """Return a class decorator that registers a planner under ``name``.

        Raises:
            AssertionError: If ``name`` is already registered.
        """

        def decorator(planner_cls: type[TensorParallelismPlanner]):
            assert (
                name not in cls._tp_planner_registry
            ), f"TensorParallelismPlanner with name {name} is already registered."
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug(f"Registering TensorParallelismPlanner: {name}")
            cls._tp_planner_registry[name] = planner_cls
            return planner_cls

        return decorator

    @classmethod
    def get_planner(
        cls, transformer: str | torch.nn.Module
    ) -> type[TensorParallelismPlanner]:
        """Return the planner class for ``transformer``.

        Args:
            transformer: A module (its class name is used) or a name string.

        Returns:
            The registered planner class whose name is a prefix of the
            transformer's name. If several registered names match, the
            longest (most specific) one wins. A hypothetical "FluxKontext"
            planner would therefore take precedence over "Flux" regardless
            of registration/import order.

        Raises:
            ValueError: If no registered name is a prefix of the name.
        """
        if isinstance(transformer, torch.nn.Module):
            name = transformer.__class__.__name__
        else:
            name = transformer

        matches = [
            registered
            for registered in cls._tp_planner_registry
            if name.startswith(registered)
        ]
        if not matches:
            raise ValueError(f"No planner registered under name: {name}")
        return cls._tp_planner_registry[max(matches, key=len)]
|
|
@@ -0,0 +1,153 @@
|
|
|
1
|
+
from typing import Optional, Union
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from torch import nn
|
|
5
|
+
from torch.distributed import DeviceMesh, init_device_mesh
|
|
6
|
+
from torch.distributed.tensor.parallel import (
|
|
7
|
+
ColwiseParallel,
|
|
8
|
+
RowwiseParallel,
|
|
9
|
+
parallelize_module,
|
|
10
|
+
)
|
|
11
|
+
|
|
12
|
+
from cache_dit.logger import init_logger
|
|
13
|
+
from cache_dit.parallelism.parallel_config import ParallelismConfig
|
|
14
|
+
|
|
15
|
+
from .tp_plan_registers import (
|
|
16
|
+
TensorParallelismPlanner,
|
|
17
|
+
TensorParallelismPlannerRegister,
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
logger = init_logger(__name__)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class DistributedRMSNorm(nn.Module):
|
|
24
|
+
def __init__(
|
|
25
|
+
self,
|
|
26
|
+
tp_mesh: DeviceMesh,
|
|
27
|
+
normalized_shape: Union[int, list[int], torch.Size],
|
|
28
|
+
eps: Optional[float],
|
|
29
|
+
elementwise_affine: bool,
|
|
30
|
+
weight: torch.nn.parameter.Parameter,
|
|
31
|
+
):
|
|
32
|
+
super().__init__()
|
|
33
|
+
self.tp_mesh = tp_mesh
|
|
34
|
+
self.elementwise_affine = elementwise_affine
|
|
35
|
+
self.normalized_shape = normalized_shape
|
|
36
|
+
self.eps = eps
|
|
37
|
+
if self.elementwise_affine:
|
|
38
|
+
assert weight is not None
|
|
39
|
+
self.weight = weight
|
|
40
|
+
|
|
41
|
+
@classmethod
|
|
42
|
+
def from_rmsnorm(cls, tp_mesh: DeviceMesh, rmsnorm: nn.RMSNorm):
|
|
43
|
+
if not isinstance(rmsnorm, int):
|
|
44
|
+
assert len(rmsnorm.normalized_shape) == 1
|
|
45
|
+
|
|
46
|
+
if rmsnorm.weight is not None:
|
|
47
|
+
tp_size = tp_mesh.get_group().size()
|
|
48
|
+
tp_rank = tp_mesh.get_group().rank()
|
|
49
|
+
weight = rmsnorm.weight.chunk(tp_size, dim=0)[tp_rank]
|
|
50
|
+
else:
|
|
51
|
+
weight = None
|
|
52
|
+
norm = cls(
|
|
53
|
+
tp_mesh=tp_mesh,
|
|
54
|
+
normalized_shape=rmsnorm.normalized_shape,
|
|
55
|
+
eps=rmsnorm.eps,
|
|
56
|
+
elementwise_affine=rmsnorm.elementwise_affine,
|
|
57
|
+
weight=weight,
|
|
58
|
+
)
|
|
59
|
+
return norm
|
|
60
|
+
|
|
61
|
+
def forward(self, x):
|
|
62
|
+
if self.elementwise_affine:
|
|
63
|
+
assert x.shape[-1] == self.weight.shape[0]
|
|
64
|
+
mean_square = torch.mean(x * x, dim=-1, keepdim=True)
|
|
65
|
+
torch.distributed.all_reduce(
|
|
66
|
+
mean_square,
|
|
67
|
+
op=torch.distributed.ReduceOp.AVG,
|
|
68
|
+
group=self.tp_mesh.get_group(),
|
|
69
|
+
)
|
|
70
|
+
root_mean_square = torch.sqrt(mean_square + self.eps)
|
|
71
|
+
x_normed = x / root_mean_square
|
|
72
|
+
if self.elementwise_affine:
|
|
73
|
+
x_normed = x_normed * self.weight.to(device=x.device)
|
|
74
|
+
assert x_normed.device.type != "cpu"
|
|
75
|
+
return x_normed
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
@TensorParallelismPlannerRegister.register("Wan")
class WanTensorParallelismPlanner(TensorParallelismPlanner):
    """Tensor-parallel plan for diffusers' Wan transformers.

    Shards self-attention (``attn1``), cross-attention (``attn2``) and the FFN
    of every block in ``transformer.blocks``. It also swaps the q/k RMSNorms
    for ``DistributedRMSNorm`` so they normalize over the full (now sharded)
    feature dimension.
    """

    def apply(
        self,
        transformer: torch.nn.Module,
        parallelism_config: ParallelismConfig,
        **kwargs,
    ) -> torch.nn.Module:
        """Create a 1-D TP mesh and parallelize ``transformer`` with it.

        ``parallelism_config.tp_size`` must be set and > 1. Extra ``kwargs``
        are accepted for interface compatibility and ignored.
        """
        tp_size = parallelism_config.tp_size
        assert tp_size is not None and tp_size > 1, (
            "parallel_config.tp_size must be set and greater than 1 for "
            "tensor parallelism"
        )

        accelerator_type = torch.accelerator.current_accelerator().type
        mesh: DeviceMesh = init_device_mesh(
            device_type=accelerator_type,
            mesh_shape=[tp_size],
        )
        return self.parallelize_transformer(transformer=transformer, tp_mesh=mesh)

    def parallelize_transformer(
        self,
        transformer: nn.Module,
        tp_mesh: DeviceMesh,
    ):
        """Shard every Wan block in place and return ``transformer``."""
        shard_count = tp_mesh.size()
        for block in transformer.blocks.children():
            # q/k/v are column-sharded, so each rank owns heads / shard_count heads.
            block.attn1.heads //= shard_count
            block.attn2.heads //= shard_count

            layer_plan = {}
            for attn_name in ("attn1", "attn2"):
                for proj in ("to_q", "to_k", "to_v"):
                    layer_plan[f"{attn_name}.{proj}"] = ColwiseParallel()
                layer_plan[f"{attn_name}.to_out.0"] = RowwiseParallel()
            layer_plan["ffn.net.0.proj"] = ColwiseParallel()
            layer_plan["ffn.net.2"] = RowwiseParallel()
            # Optional extra key/value projections on attn2 (only present on
            # some Wan variants; absent/None otherwise).
            for proj in ("add_k_proj", "add_v_proj"):
                if getattr(block.attn2, proj, None):
                    layer_plan[f"attn2.{proj}"] = ColwiseParallel()

            parallelize_module(
                module=block,
                device_mesh=tp_mesh,
                parallelize_plan=layer_plan,
            )

            # These norms span the column-sharded projection features, so
            # they need a group-wide RMS statistic.
            for attn in (block.attn1, block.attn2):
                attn.norm_q = DistributedRMSNorm.from_rmsnorm(tp_mesh, attn.norm_q)
                attn.norm_k = DistributedRMSNorm.from_rmsnorm(tp_mesh, attn.norm_k)
            if getattr(block.attn2, "norm_added_k", None):
                block.attn2.norm_added_k = DistributedRMSNorm.from_rmsnorm(
                    tp_mesh, block.attn2.norm_added_k
                )
        return transformer
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
# NOTE: must import all planner classes to register them
|
|
2
|
+
from .tp_plan_registers import TensorParallelismPlannerRegister
|
|
3
|
+
from .tp_plan_flux import FluxTensorParallelismPlanner
|
|
4
|
+
from .tp_plan_qwen_image import QwenImageTensorParallelismPlanner
|
|
5
|
+
from .tp_plan_wan import WanTensorParallelismPlanner
|
|
6
|
+
|
|
7
|
+
__all__ = [
|
|
8
|
+
"TensorParallelismPlannerRegister",
|
|
9
|
+
"FluxTensorParallelismPlanner",
|
|
10
|
+
"QwenImageTensorParallelismPlanner",
|
|
11
|
+
"WanTensorParallelismPlanner",
|
|
12
|
+
]
|
|
@@ -8,6 +8,8 @@ class ParallelismBackend(Enum):
|
|
|
8
8
|
|
|
9
9
|
@classmethod
|
|
10
10
|
def is_supported(cls, backend: "ParallelismBackend") -> bool:
|
|
11
|
+
if backend in [cls.NATIVE_PYTORCH]:
|
|
12
|
+
return True
|
|
11
13
|
# Now, only Native_Diffuser backend is supported
|
|
12
14
|
if backend in [cls.NATIVE_DIFFUSER]:
|
|
13
15
|
try:
|
|
@@ -23,8 +23,8 @@ class ParallelismConfig:
|
|
|
23
23
|
tp_size: int = None
|
|
24
24
|
# parallel_kwargs (`dict`, *optional*):
|
|
25
25
|
# Additional kwargs for parallelism backends. For example, for
|
|
26
|
-
# NATIVE_DIFFUSER backend, it can include `cp_plan` and
|
|
27
|
-
# arguments for `Context Parallelism`.
|
|
26
|
+
# NATIVE_DIFFUSER backend, it can include `cp_plan` and
|
|
27
|
+
# `attention_backend` arguments for `Context Parallelism`.
|
|
28
28
|
parallel_kwargs: Optional[Dict[str, Any]] = dataclasses.field(
|
|
29
29
|
default_factory=dict
|
|
30
30
|
)
|
|
@@ -34,7 +34,14 @@ class ParallelismConfig:
|
|
|
34
34
|
f"Parallel backend {self.backend} is not supported. "
|
|
35
35
|
f"Please make sure the required packages are installed."
|
|
36
36
|
)
|
|
37
|
-
|
|
37
|
+
|
|
38
|
+
if self.tp_size is not None and self.tp_size > 1:
|
|
39
|
+
assert (
|
|
40
|
+
self.ulysses_size is None or self.ulysses_size == 1
|
|
41
|
+
), "Tensor parallelism plus Ulysses parallelism is not supported right now."
|
|
42
|
+
assert (
|
|
43
|
+
self.ring_size is None or self.ring_size == 1
|
|
44
|
+
), "Tensor parallelism plus Ring parallelism is not supported right now."
|
|
38
45
|
|
|
39
46
|
def strify(self, details: bool = False) -> str:
|
|
40
47
|
if details:
|
|
@@ -24,12 +24,17 @@ def enable_parallelism(
|
|
|
24
24
|
if parallelism_config.backend == ParallelismBackend.NATIVE_DIFFUSER:
|
|
25
25
|
from cache_dit.parallelism.backends.native_diffusers import (
|
|
26
26
|
maybe_enable_parallelism,
|
|
27
|
-
native_diffusers_parallelism_available,
|
|
28
27
|
)
|
|
29
28
|
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
29
|
+
transformer = maybe_enable_parallelism(
|
|
30
|
+
transformer,
|
|
31
|
+
parallelism_config,
|
|
32
|
+
)
|
|
33
|
+
elif parallelism_config.backend == ParallelismBackend.NATIVE_PYTORCH:
|
|
34
|
+
from cache_dit.parallelism.backends.native_pytorch import (
|
|
35
|
+
maybe_enable_parallelism,
|
|
36
|
+
)
|
|
37
|
+
|
|
33
38
|
transformer = maybe_enable_parallelism(
|
|
34
39
|
transformer,
|
|
35
40
|
parallelism_config,
|
|
@@ -40,8 +45,12 @@ def enable_parallelism(
|
|
|
40
45
|
)
|
|
41
46
|
|
|
42
47
|
transformer._is_parallelized = True # type: ignore[attr-defined]
|
|
48
|
+
# Use `parallelism` not `parallel` to avoid name conflict with diffusers.
|
|
43
49
|
transformer._parallelism_config = parallelism_config # type: ignore[attr-defined]
|
|
44
|
-
logger.info(
|
|
50
|
+
logger.info(
|
|
51
|
+
f"Enabled parallelism: {parallelism_config.strify(True)}, "
|
|
52
|
+
f"transformer id:{id(transformer)}"
|
|
53
|
+
)
|
|
45
54
|
return transformer
|
|
46
55
|
|
|
47
56
|
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .torchao import quantize_ao
|
|
File without changes
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .quantize_ao import quantize_ao
|
|
@@ -9,7 +9,7 @@ logger = init_logger(__name__)
|
|
|
9
9
|
|
|
10
10
|
def quantize_ao(
|
|
11
11
|
module: torch.nn.Module,
|
|
12
|
-
quant_type: str = "
|
|
12
|
+
quant_type: str = "float8_weight_only",
|
|
13
13
|
exclude_layers: List[str] = [
|
|
14
14
|
"embedder",
|
|
15
15
|
"embed",
|
|
@@ -24,6 +24,18 @@ def quantize_ao(
|
|
|
24
24
|
# set `exclude_layers` as `[]` if you don't want this behavior.
|
|
25
25
|
assert isinstance(module, torch.nn.Module)
|
|
26
26
|
|
|
27
|
+
alias_map = {
|
|
28
|
+
"float8": "fp8_w8a8_dq",
|
|
29
|
+
"float8_weight_only": "fp8_w8a16_wo",
|
|
30
|
+
"int8": "int8_w8a8_dq",
|
|
31
|
+
"int8_weight_only": "int8_w8a16_wo",
|
|
32
|
+
"int4": "int4_w4a8_dq",
|
|
33
|
+
"int4_w4a4": "int4_w4a4_dq",
|
|
34
|
+
"int4_weight_only": "int4_w4a16_wo",
|
|
35
|
+
}
|
|
36
|
+
if quant_type.lower() in alias_map:
|
|
37
|
+
quant_type = alias_map[quant_type.lower()]
|
|
38
|
+
|
|
27
39
|
quant_type = quant_type.lower()
|
|
28
40
|
assert quant_type in (
|
|
29
41
|
"fp8_w8a8_dq",
|
|
@@ -183,7 +195,11 @@ def quantize_ao(
|
|
|
183
195
|
device=kwargs.get("device", None),
|
|
184
196
|
)
|
|
185
197
|
|
|
186
|
-
|
|
198
|
+
maybe_empty_cache()
|
|
199
|
+
|
|
200
|
+
alias_map_rev = {v: k for k, v in alias_map.items()}
|
|
201
|
+
if quant_type in alias_map_rev:
|
|
202
|
+
quant_type = alias_map_rev[quant_type]
|
|
187
203
|
|
|
188
204
|
logger.info(
|
|
189
205
|
f"Quantized Module: {module.__class__.__name__:>5}\n"
|
|
@@ -199,10 +215,13 @@ def quantize_ao(
|
|
|
199
215
|
return module
|
|
200
216
|
|
|
201
217
|
|
|
202
|
-
def
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
218
|
+
def maybe_empty_cache(sleep_seconds: float = 1.0, rounds: int = 2):
    """Best-effort release of unreferenced Python objects and CUDA cache.

    Runs ``rounds`` cycles of ``sleep -> gc.collect() -> torch.cuda.empty_cache()``.
    The defaults reproduce the historical behavior: two cycles, each with a
    one-second pause. The pause is presumably meant to let pending frees
    settle; NOTE(review): confirm it is still needed. Callers that do not
    want the pause can pass ``sleep_seconds=0``.

    This is a cleanup hint only, so it never raises. Failures are logged at
    debug level instead of being silently discarded.
    """
    try:
        for _ in range(rounds):
            if sleep_seconds > 0:
                time.sleep(sleep_seconds)
            gc.collect()
            torch.cuda.empty_cache()
    except Exception as e:  # deliberate best-effort: never fail on cleanup
        logger.debug(f"maybe_empty_cache: cache cleanup failed: {e}")
|
|
File without changes
|
|
File without changes
|